mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-22 23:14:33 +08:00
Compare commits
8 Commits
bda4bdda77
...
07825b50b5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
07825b50b5 | ||
|
|
b70b7ed01a | ||
|
|
335e1f9ada | ||
|
|
33268bf3d9 | ||
|
|
86b9d2013a | ||
|
|
aced1dfd3e | ||
|
|
aaf2d46a56 | ||
|
|
c8abea9c1a |
@@ -9,7 +9,8 @@ public enum EngineType {
|
||||
POSTGRESQL(6, "POSTGRESQL"),
|
||||
OTHER(7, "OTHER"),
|
||||
DUCKDB(8, "DUCKDB"),
|
||||
HANADB(9, "HANADB");
|
||||
HANADB(9, "HANADB"),
|
||||
STARROCKS(10, "STARROCKS"),;
|
||||
|
||||
private Integer code;
|
||||
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
import lombok.*;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@@ -21,6 +17,7 @@ public class SchemaElementMatch implements Serializable {
|
||||
private String word;
|
||||
private Long frequency;
|
||||
private boolean isInherited;
|
||||
private boolean llmMatched;
|
||||
|
||||
public boolean isFullMatched() {
|
||||
return 1.0 == similarity;
|
||||
|
||||
@@ -13,6 +13,7 @@ public class EmbeddingResult extends MapResult {
|
||||
|
||||
private String id;
|
||||
private Map<String, String> metadata;
|
||||
private boolean llmMatched;
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||
@@ -11,6 +14,7 @@ import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
@@ -23,10 +27,16 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public boolean accept(ChatQueryContext chatQueryContext) {
|
||||
return MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum());
|
||||
boolean b0 = MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum());
|
||||
boolean b1 = chatQueryContext.getRequest().getText2SQLType() == Text2SQLType.LLM_OR_RULE;
|
||||
return b0 || b1;
|
||||
}
|
||||
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
|
||||
// TODO: 如果是在LOOSE执行过了,那么在LLM_OR_RULE阶段可以不用执行,所以这里缺乏一个状态来传递,暂时先忽略这个浪费行为吧
|
||||
SchemaMapInfo mappedInfo = chatQueryContext.getMapInfo();
|
||||
|
||||
// 1. Query from embedding by queryText
|
||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
|
||||
@@ -53,15 +63,26 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
// Build SchemaElementMatch object
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.word(matchResult.getName()).similarity(matchResult.getSimilarity())
|
||||
.detectWord(matchResult.getDetectWord()).build();
|
||||
schemaElementMatch.setLlmMatched(matchResult.isLlmMatched());
|
||||
|
||||
// 3. Add SchemaElementMatch to mapInfo
|
||||
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||
}
|
||||
if (CollectionUtils.isEmpty(matchResults)) {
|
||||
log.info("embedding mapper no match");
|
||||
} else {
|
||||
for (EmbeddingResult matchResult : matchResults) {
|
||||
log.info("embedding match name=[{}],detectWord=[{}],similarity=[{}],metadata=[{}]",
|
||||
matchResult.getName(), matchResult.getDetectWord(),
|
||||
matchResult.getSimilarity(), JsonUtil.toString(matchResult.getMetadata()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
@@ -14,18 +22,12 @@ import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER;
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER;
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_THRESHOLD;
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.*;
|
||||
|
||||
/**
|
||||
* EmbeddingMatchStrategy uses vector database to perform similarity search against the embeddings
|
||||
@@ -35,37 +37,165 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING
|
||||
@Slf4j
|
||||
public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult> {
|
||||
|
||||
@Autowired
|
||||
protected MapperConfig mapperConfig;
|
||||
|
||||
@Autowired
|
||||
private MetaEmbeddingService metaEmbeddingService;
|
||||
|
||||
private static final String LLM_FILTER_PROMPT =
|
||||
"""
|
||||
\
|
||||
#Role: You are a professional data analyst specializing in metrics and dimensions.
|
||||
#Task: Given a user query and a list of retrieved metrics/dimensions through vector recall,
|
||||
please analyze which metrics/dimensions the user is most likely interested in.
|
||||
#Rules:
|
||||
1. Based on user query and retrieved info, accurately determine metrics/dimensions user truly cares about.
|
||||
2. Do not return all retrieved info, only select those highly relevant to user query.
|
||||
3. Maintain high quality output, exclude metrics/dimensions irrelevant to user intent.
|
||||
4. Output must be in JSON array format, only include IDs from retrieved info, e.g.: ['id1', 'id2']
|
||||
5. Return JSON content directly without markdown formatting
|
||||
#Input Example:
|
||||
#User Query: {{userText}}
|
||||
#Retrieved Metrics/Dimensions: {{retrievedInfo}}
|
||||
#Output:""";
|
||||
|
||||
@Override
|
||||
public List<EmbeddingResult> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
if (chatQueryContext == null || CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||
log.warn("Invalid input parameters: context={}, dataSetIds={}", chatQueryContext,
|
||||
detectDataSetIds);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
// 1. Base detection
|
||||
List<EmbeddingResult> baseResults = super.detect(chatQueryContext, terms, detectDataSetIds);
|
||||
|
||||
boolean useLLM = Boolean.parseBoolean(mapperConfig.getParameterValue(EMBEDDING_MAPPER_USE_LLM));
|
||||
|
||||
// 2. LLM enhanced detection
|
||||
if (useLLM) {
|
||||
List<EmbeddingResult> llmResults = detectWithLLM(chatQueryContext, detectDataSetIds);
|
||||
if (!CollectionUtils.isEmpty(llmResults)) {
|
||||
baseResults.addAll(llmResults);
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Deduplicate results
|
||||
return baseResults.stream().distinct().collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform enhanced detection using LLM
|
||||
*/
|
||||
private List<EmbeddingResult> detectWithLLM(ChatQueryContext chatQueryContext,
|
||||
Set<Long> detectDataSetIds) {
|
||||
try {
|
||||
String queryText = chatQueryContext.getRequest().getQueryText();
|
||||
if (StringUtils.isBlank(queryText)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
// Get segmentation results
|
||||
Set<String> detectSegments = extractValidSegments(queryText);
|
||||
if (CollectionUtils.isEmpty(detectSegments)) {
|
||||
log.info("No valid segments found for text: {}", queryText);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments, true);
|
||||
} catch (Exception e) {
|
||||
log.error("Error in LLM detection for context: {}", chatQueryContext, e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract valid word segments by filtering out unwanted word natures
|
||||
*/
|
||||
private Set<String> extractValidSegments(String text) {
|
||||
List<String> natureList = Arrays.asList(StringUtils.split(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE ), ","));
|
||||
return HanlpHelper.getSegment().seg(text).stream()
|
||||
.filter(t -> natureList.stream().noneMatch(nature -> t.nature.startsWith(nature)))
|
||||
.map(Term::getWord).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
||||
return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Process detection in batches with LLM option
|
||||
*
|
||||
* @param chatQueryContext The context of the chat query
|
||||
* @param detectDataSetIds Target dataset IDs for detection
|
||||
* @param detectSegments Segments to be detected
|
||||
* @param useLlm Whether to use LLM for filtering results
|
||||
* @return List of embedding results
|
||||
*/
|
||||
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
||||
Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
|
||||
Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
|
||||
int embeddingMapperBatch = Integer
|
||||
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
||||
|
||||
List<String> queryTextsList =
|
||||
detectSegments.stream().map(detectSegment -> detectSegment.trim())
|
||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment))
|
||||
.collect(Collectors.toList());
|
||||
// Process and filter query texts
|
||||
List<String> queryTextsList = detectSegments.stream().map(String::trim)
|
||||
.filter(StringUtils::isNotBlank).collect(Collectors.toList());
|
||||
|
||||
// Partition queries into sub-lists for batch processing
|
||||
List<List<String>> queryTextsSubList =
|
||||
Lists.partition(queryTextsList, embeddingMapperBatch);
|
||||
|
||||
// Create and execute tasks for each batch
|
||||
List<Callable<Void>> tasks = new ArrayList<>();
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results));
|
||||
tasks.add(
|
||||
createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results, useLlm));
|
||||
}
|
||||
executeTasks(tasks);
|
||||
|
||||
// Apply LLM filtering if enabled
|
||||
if (useLlm) {
|
||||
Map<String, Object> variable = new HashMap<>();
|
||||
variable.put("userText", chatQueryContext.getRequest().getQueryText());
|
||||
variable.put("retrievedInfo", JSONObject.toJSONString(results));
|
||||
|
||||
Prompt prompt = PromptTemplate.from(LLM_FILTER_PROMPT).apply(variable);
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel();
|
||||
String response = chatLanguageModel.generate(prompt.toUserMessage().singleText());
|
||||
|
||||
if (StringUtils.isBlank(response)) {
|
||||
results.clear();
|
||||
} else {
|
||||
List<String> retrievedIds = JSONObject.parseArray(response, String.class);
|
||||
results = results.stream().filter(t -> retrievedIds.contains(t.getId()))
|
||||
.collect(Collectors.toSet());
|
||||
results.forEach(r -> r.setLlmMatched(true));
|
||||
}
|
||||
}
|
||||
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a task for batch processing
|
||||
*
|
||||
* @param chatQueryContext The context of the chat query
|
||||
* @param detectDataSetIds Target dataset IDs
|
||||
* @param queryTextsSub Sub-list of query texts to process
|
||||
* @param results Shared result set for collecting results
|
||||
* @param useLlm Whether to use LLM
|
||||
* @return Callable task
|
||||
*/
|
||||
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
|
||||
List<String> queryTextsSub, Set<EmbeddingResult> results) {
|
||||
List<String> queryTextsSub, Set<EmbeddingResult> results, boolean useLlm) {
|
||||
return () -> {
|
||||
List<EmbeddingResult> oneRoundResults =
|
||||
detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext);
|
||||
List<EmbeddingResult> oneRoundResults = detectByQueryTextsSub(detectDataSetIds,
|
||||
queryTextsSub, chatQueryContext, useLlm);
|
||||
synchronized (results) {
|
||||
selectResultInOneRound(results, oneRoundResults);
|
||||
}
|
||||
@@ -73,57 +203,73 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a sub-list of query texts
|
||||
*
|
||||
* @param detectDataSetIds Target dataset IDs
|
||||
* @param queryTextsSub Sub-list of query texts
|
||||
* @param chatQueryContext Chat query context
|
||||
* @param useLlm Whether to use LLM
|
||||
* @return List of embedding results for this batch
|
||||
*/
|
||||
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
|
||||
List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
|
||||
List<String> queryTextsSub, ChatQueryContext chatQueryContext, boolean useLlm) {
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
|
||||
|
||||
// Get configuration parameters
|
||||
double threshold =
|
||||
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
||||
|
||||
// step1. build query params
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
|
||||
// step2. retrieveQuery by detectSegment
|
||||
Double.parseDouble(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
||||
int embeddingNumber =
|
||||
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
|
||||
Integer.parseInt(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
|
||||
int embeddingRoundNumber =
|
||||
Integer.parseInt(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
|
||||
|
||||
// Build and execute query
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return new ArrayList<>();
|
||||
return Collections.emptyList();
|
||||
}
|
||||
// step3. build EmbeddingResults
|
||||
List<EmbeddingResult> collect = retrieveQueryResults.stream().map(retrieveQueryResult -> {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||
retrievals.removeIf(retrieval -> {
|
||||
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
|
||||
return retrieval.getSimilarity() < threshold;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
// Process results
|
||||
List<EmbeddingResult> collect = retrieveQueryResults.stream().peek(result -> {
|
||||
if (!useLlm && CollectionUtils.isNotEmpty(result.getRetrieval())) {
|
||||
result.getRetrieval()
|
||||
.removeIf(retrieval -> !result.getQuery().contains(retrieval.getQuery())
|
||||
&& retrieval.getSimilarity() < threshold);
|
||||
}
|
||||
return retrieveQueryResult;
|
||||
}).filter(retrieveQueryResult -> CollectionUtils
|
||||
.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
||||
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
|
||||
.map(retrieval -> {
|
||||
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
||||
embeddingResult.setName(retrieval.getQuery());
|
||||
Map<String, String> convertedMap = retrieval.getMetadata().entrySet()
|
||||
.stream().collect(Collectors.toMap(Map.Entry::getKey,
|
||||
entry -> entry.getValue().toString()));
|
||||
embeddingResult.setMetadata(convertedMap);
|
||||
return embeddingResult;
|
||||
}))
|
||||
}).filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
|
||||
.flatMap(result -> result.getRetrieval().stream()
|
||||
.map(retrieval -> convertToEmbeddingResult(result, retrieval)))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// step4. select mapResul in one round
|
||||
int embeddingRoundNumber =
|
||||
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
|
||||
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
|
||||
return collect.stream().sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity))
|
||||
.limit(roundNumber).collect(Collectors.toList());
|
||||
// Sort and limit results
|
||||
return collect.stream()
|
||||
.sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity).reversed())
|
||||
.limit(embeddingRoundNumber * queryTextsSub.size()).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert RetrieveQueryResult and Retrieval to EmbeddingResult
|
||||
*
|
||||
* @param queryResult The query result containing retrieval information
|
||||
* @param retrieval The retrieval data to be converted
|
||||
* @return Converted EmbeddingResult
|
||||
*/
|
||||
private EmbeddingResult convertToEmbeddingResult(RetrieveQueryResult queryResult,
|
||||
Retrieval retrieval) {
|
||||
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||
embeddingResult.setDetectWord(queryResult.getQuery());
|
||||
embeddingResult.setName(retrieval.getQuery());
|
||||
|
||||
// Convert metadata to string values
|
||||
Map<String, String> metadata = retrieval.getMetadata().entrySet().stream().collect(
|
||||
Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue())));
|
||||
embeddingResult.setMetadata(metadata);
|
||||
|
||||
return embeddingResult;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,12 +7,7 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -66,7 +61,7 @@ public class MapFilter {
|
||||
List<SchemaElementMatch> value = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(value)) {
|
||||
value.removeIf(schemaElementMatch -> StringUtils
|
||||
.length(schemaElementMatch.getDetectWord()) <= 1);
|
||||
.length(schemaElementMatch.getDetectWord()) <= 1 && !schemaElementMatch.isLlmMatched());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -85,7 +80,7 @@ public class MapFilter {
|
||||
}
|
||||
|
||||
public static void filterByQueryDataType(ChatQueryContext chatQueryContext,
|
||||
Predicate<SchemaElement> needRemovePredicate) {
|
||||
Predicate<SchemaElement> needRemovePredicate) {
|
||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
||||
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
||||
|
||||
@@ -57,4 +57,12 @@ public class MapperConfig extends ParameterConfig {
|
||||
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
|
||||
new Parameter("s2.mapper.embedding.round.number", "10", "向量召回最小相似度阈值",
|
||||
"向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_USE_LLM =
|
||||
new Parameter("s2.mapper.embedding.use-llm-enhance", "false", "使用LLM对召回的向量进行二次判断开关",
|
||||
"embedding的结果再通过一次LLM来筛选,这时候忽略各个向量阀值", "bool", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE =
|
||||
new Parameter("s2.mapper.embedding.allowed-segment-nature", "['v', 'd', 'a']", "使用LLM召回二次处理时对问题分词词性的控制",
|
||||
"分词后允许的词性才会进行向量召回", "list", "Mapper相关配置");
|
||||
}
|
||||
|
||||
@@ -106,8 +106,6 @@ public class PromptHelper {
|
||||
}
|
||||
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
||||
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
|
||||
} else {
|
||||
metricStr.append(" AGGREGATE 'NONE'");
|
||||
}
|
||||
metricStr.append(">");
|
||||
metrics.add(metricStr.toString());
|
||||
|
||||
@@ -17,7 +17,20 @@ import java.util.List;
|
||||
@Slf4j
|
||||
public abstract class BaseDbAdaptor implements DbAdaptor {
|
||||
|
||||
public List<String> getDBs(ConnectInfo connectionInfo) throws SQLException {
|
||||
@Override
|
||||
public List<String> getCatalogs(ConnectInfo connectInfo) throws SQLException {
|
||||
// Apart from supporting multiple catalog types of data sources, other types will return an
|
||||
// empty set by default.
|
||||
return List.of();
|
||||
}
|
||||
|
||||
public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException {
|
||||
// Except for special types implemented separately, the generic logic catalog does not take
|
||||
// effect.
|
||||
return getDBs(connectionInfo);
|
||||
}
|
||||
|
||||
protected List<String> getDBs(ConnectInfo connectionInfo) throws SQLException {
|
||||
List<String> dbs = Lists.newArrayList();
|
||||
DatabaseMetaData metaData = getDatabaseMetaData(connectionInfo);
|
||||
try {
|
||||
|
||||
@@ -14,7 +14,9 @@ public interface DbAdaptor {
|
||||
|
||||
String rewriteSql(String sql);
|
||||
|
||||
List<String> getDBs(ConnectInfo connectInfo) throws SQLException;
|
||||
List<String> getCatalogs(ConnectInfo connectInfo) throws SQLException;
|
||||
|
||||
List<String> getDBs(ConnectInfo connectInfo, String catalog) throws SQLException;
|
||||
|
||||
List<String> getTables(ConnectInfo connectInfo, String schemaName) throws SQLException;
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ public class DbAdaptorFactory {
|
||||
dbAdaptorMap.put(EngineType.OTHER.getName(), new DefaultDbAdaptor());
|
||||
dbAdaptorMap.put(EngineType.DUCKDB.getName(), new DuckdbAdaptor());
|
||||
dbAdaptorMap.put(EngineType.HANADB.getName(), new HanadbAdaptor());
|
||||
dbAdaptorMap.put(EngineType.STARROCKS.getName(), new StarrocksAdaptor());
|
||||
}
|
||||
|
||||
public static DbAdaptor getEngineAdaptor(String engineType) {
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package com.tencent.supersonic.headless.core.adaptor.db;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.sql.*;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class StarrocksAdaptor extends MysqlAdaptor {
|
||||
|
||||
@Override
|
||||
public List<String> getCatalogs(ConnectInfo connectInfo) throws SQLException {
|
||||
List<String> catalogs = Lists.newArrayList();
|
||||
try (Connection con = DriverManager.getConnection(connectInfo.getUrl(),
|
||||
connectInfo.getUserName(), connectInfo.getPassword());
|
||||
Statement st = con.createStatement();
|
||||
ResultSet rs = st.executeQuery("SHOW CATALOGS")) {
|
||||
while (rs.next()) {
|
||||
catalogs.add(rs.getString(1));
|
||||
}
|
||||
}
|
||||
return catalogs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException {
|
||||
Assert.hasText(catalog, "StarRocks type catalog can not be null or empty");
|
||||
List<String> dbs = Lists.newArrayList();
|
||||
try (Connection con = DriverManager.getConnection(connectionInfo.getUrl(),
|
||||
connectionInfo.getUserName(), connectionInfo.getPassword());
|
||||
Statement st = con.createStatement();
|
||||
ResultSet rs = st.executeQuery("SHOW DATABASES IN " + catalog)) {
|
||||
while (rs.next()) {
|
||||
dbs.add(rs.getString(1));
|
||||
}
|
||||
}
|
||||
return dbs;
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ public class DbParameterFactory {
|
||||
parametersBuilder.put(EngineType.MYSQL.getName(), new MysqlParametersBuilder());
|
||||
parametersBuilder.put(EngineType.POSTGRESQL.getName(), new PostgresqlParametersBuilder());
|
||||
parametersBuilder.put(EngineType.HANADB.getName(), new HanadbParametersBuilder());
|
||||
parametersBuilder.put(EngineType.STARROCKS.getName(), new StarrocksParametersBuilder());
|
||||
parametersBuilder.put(EngineType.OTHER.getName(), new OtherParametersBuilder());
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.headless.server.pojo;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class StarrocksParametersBuilder extends DefaultParametersBuilder {
|
||||
|
||||
@Override
|
||||
public List<DatabaseParameter> build() {
|
||||
return super.build();
|
||||
}
|
||||
}
|
||||
@@ -76,9 +76,15 @@ public class DatabaseController {
|
||||
return databaseService.executeSql(sqlExecuteReq, user);
|
||||
}
|
||||
|
||||
@RequestMapping("/getCatalogs")
|
||||
public List<String> getCatalogs(@RequestParam("id") Long databaseId) throws SQLException {
|
||||
return databaseService.getCatalogs(databaseId);
|
||||
}
|
||||
|
||||
@RequestMapping("/getDbNames")
|
||||
public List<String> getDbNames(@RequestParam("id") Long databaseId) throws SQLException {
|
||||
return databaseService.getDbNames(databaseId);
|
||||
public List<String> getDbNames(@RequestParam("id") Long databaseId,
|
||||
@RequestParam(value = "catalog", required = false) String catalog) throws SQLException {
|
||||
return databaseService.getDbNames(databaseId, catalog);
|
||||
}
|
||||
|
||||
@RequestMapping("/getTables")
|
||||
|
||||
@@ -36,7 +36,9 @@ public interface DatabaseService {
|
||||
|
||||
void deleteDatabase(Long databaseId);
|
||||
|
||||
List<String> getDbNames(Long id) throws SQLException;
|
||||
List<String> getCatalogs(Long id) throws SQLException;
|
||||
|
||||
List<String> getDbNames(Long id, String catalog) throws SQLException;
|
||||
|
||||
List<String> getTables(Long id, String db) throws SQLException;
|
||||
|
||||
|
||||
@@ -200,10 +200,17 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getDbNames(Long id) throws SQLException {
|
||||
public List<String> getCatalogs(Long id) throws SQLException {
|
||||
DatabaseResp databaseResp = getDatabase(id);
|
||||
DbAdaptor dbAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
|
||||
return dbAdaptor.getDBs(DatabaseConverter.getConnectInfo(databaseResp));
|
||||
return dbAdaptor.getCatalogs(DatabaseConverter.getConnectInfo(databaseResp));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getDbNames(Long id, String catalog) throws SQLException {
|
||||
DatabaseResp databaseResp = getDatabase(id);
|
||||
DbAdaptor dbAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
|
||||
return dbAdaptor.getDBs(DatabaseConverter.getConnectInfo(databaseResp), catalog);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -104,11 +104,11 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
|
||||
} else {
|
||||
MetricResp metricRespByBizName = bizNameMap.get(metric.getBizName());
|
||||
MetricResp metricRespByName = nameMap.get(metric.getName());
|
||||
if (null != metricRespByBizName && isChange(metric, metricRespByBizName)) {
|
||||
if (null != metricRespByBizName) {
|
||||
metric.setId(metricRespByBizName.getId());
|
||||
this.updateMetric(metric, user);
|
||||
} else {
|
||||
if (null != metricRespByName && isChange(metric, metricRespByName)) {
|
||||
if (null != metricRespByName) {
|
||||
metric.setId(metricRespByName.getId());
|
||||
this.updateMetric(metric, user);
|
||||
}
|
||||
@@ -819,7 +819,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
|
||||
return modelResps.stream().map(ModelResp::getId).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
private boolean isChange(MetricReq metricReq, MetricResp metricResp) {
|
||||
private boolean isNameChange(MetricReq metricReq, MetricResp metricResp) {
|
||||
boolean isNameChange = !metricReq.getName().equals(metricResp.getName());
|
||||
return isNameChange;
|
||||
}
|
||||
|
||||
@@ -142,8 +142,12 @@ public class ModelServiceImpl implements ModelService {
|
||||
@Override
|
||||
@Transactional
|
||||
public ModelResp updateModel(ModelReq modelReq, User user) throws Exception {
|
||||
// checkParams(modelReq);
|
||||
// Comment out below checks for now, they seem unnecessary and
|
||||
// lead to unexpected exception in updating model
|
||||
/*
|
||||
checkParams(modelReq);
|
||||
checkRelations(modelReq);
|
||||
*/
|
||||
ModelDO modelDO = modelRepository.getModelById(modelReq.getId());
|
||||
ModelConverter.convert(modelDO, modelReq, user);
|
||||
modelRepository.updateModel(modelDO);
|
||||
|
||||
@@ -17,6 +17,25 @@
|
||||
<start-class>com.tencent.supersonic.StandaloneLauncher</start-class>
|
||||
</properties>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springdoc</groupId>
|
||||
<artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
|
||||
<version>2.1.0</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-expression</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-beans</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-webmvc</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<artifactId>launchers-common</artifactId>
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
spring:
|
||||
datasource:
|
||||
driver-class-name: org.postgresql.Driver
|
||||
url: jdbc:postgresql://${S2_DB_HOST:localhost}:${S2_DB_PORT:5432}/${S2_DB_DATABASE:postgres}?stringtype=unspecified
|
||||
username: ${S2_DB_USER:postgres}
|
||||
password: ${S2_DB_PASSWORD:postgres}
|
||||
url: jdbc:postgresql://localhost:5432/postgres?stringtype=unspecified
|
||||
username: postgres
|
||||
password: postgres
|
||||
sql:
|
||||
init:
|
||||
mode: always
|
||||
username: ${S2_DB_USER:postgres}
|
||||
password: ${S2_DB_PASSWORD:postgres}
|
||||
username: postgres
|
||||
password: postgres
|
||||
schema-locations: classpath:db/schema-postgres.sql,classpath:db/schema-postgres-demo.sql
|
||||
data-locations: classpath:db/data-postgres.sql,classpath:db/data-postgres-demo.sql
|
||||
|
||||
@@ -17,9 +17,9 @@ s2:
|
||||
store:
|
||||
provider: PGVECTOR
|
||||
base:
|
||||
url: ${S2_DB_HOST:127.0.0.1}
|
||||
port: ${S2_DB_PORT:5432}
|
||||
databaseName: ${S2_DB_DATABASE:postgres}
|
||||
user: ${S2_DB_USER:postgres}
|
||||
password: ${S2_DB_PASSWORD:postgres}
|
||||
url: 127.0.0.1
|
||||
port: 5432
|
||||
databaseName: postgres
|
||||
user: postgres
|
||||
password: postgres
|
||||
dimension: 512
|
||||
@@ -5,11 +5,11 @@ DELETE FROM s2_canvas;
|
||||
|
||||
-- sample user
|
||||
-- The default value for the password is 123456
|
||||
insert into s2_user (id, `name`, password, salt, display_name, email, is_admin) values (1, 'admin','c3VwZXJzb25pY0BiaWNvbdktJJYWw6A3rEmBUPzbn/6DNeYnD+y3mAwDKEMS3KVT','jGl25bVBBBW96Qi9Te4V3w==','admin','admin@xx.com', 1);
|
||||
insert into s2_user (id, `name`, password, salt, display_name, email) values (2, 'jack','c3VwZXJzb25pY0BiaWNvbWxGalmwa0h/trkh/3CWOYMDiku0Op1VmOfESIKmN0HG','MWERWefm/3hD6kYndF6JIg==','jack','jack@xx.com');
|
||||
insert into s2_user (id, `name`, password, salt, display_name, email) values (3, 'tom','c3VwZXJzb25pY0BiaWNvbVWv0CZ6HzeX8GRUpw0C8NSaQ+0hE/dAcmzRpCFwAqxK','4WCPdcXXgT89QDHLML+3hg==','tom','tom@xx.com');
|
||||
insert into s2_user (id, `name`, password, salt, display_name, email, is_admin) values (4, 'lucy','c3VwZXJzb25pY0BiaWNvbc7Ychfu99lPL7rLmCkf/vgF4RASa4Z++Mxo1qlDCpci','3Jnpqob6uDoGLP9eCAg5Fw==','lucy','lucy@xx.com', 1);
|
||||
insert into s2_user (id, `name`, password, salt, display_name, email) values (5, 'alice','c3VwZXJzb25pY0BiaWNvbe9Z4F2/DVIfAJoN1HwUTuH1KgVuiusvfh7KkWYQSNHk','K9gGyX8OAK8aH8Myj6djqQ==','alice','alice@xx.com');
|
||||
INSERT INTO s2_user (`name`, password, salt, display_name, email, is_admin) values ('admin','c3VwZXJzb25pY0BiaWNvbdktJJYWw6A3rEmBUPzbn/6DNeYnD+y3mAwDKEMS3KVT','jGl25bVBBBW96Qi9Te4V3w==','admin','admin@xx.com', 1);
|
||||
INSERT INTO s2_user (`name`, password, salt, display_name, email) values ('jack','c3VwZXJzb25pY0BiaWNvbWxGalmwa0h/trkh/3CWOYMDiku0Op1VmOfESIKmN0HG','MWERWefm/3hD6kYndF6JIg==','jack','jack@xx.com');
|
||||
INSERT INTO s2_user (`name`, password, salt, display_name, email) values ('tom','c3VwZXJzb25pY0BiaWNvbVWv0CZ6HzeX8GRUpw0C8NSaQ+0hE/dAcmzRpCFwAqxK','4WCPdcXXgT89QDHLML+3hg==','tom','tom@xx.com');
|
||||
INSERT INTO s2_user (`name`, password, salt, display_name, email, is_admin) values ('lucy','c3VwZXJzb25pY0BiaWNvbc7Ychfu99lPL7rLmCkf/vgF4RASa4Z++Mxo1qlDCpci','3Jnpqob6uDoGLP9eCAg5Fw==','lucy','lucy@xx.com', 1);
|
||||
INSERT INTO s2_user (`name`, password, salt, display_name, email) values ('alice','c3VwZXJzb25pY0BiaWNvbe9Z4F2/DVIfAJoN1HwUTuH1KgVuiusvfh7KkWYQSNHk','K9gGyX8OAK8aH8Myj6djqQ==','alice','alice@xx.com');
|
||||
|
||||
|
||||
INSERT INTO s2_available_date_info (`item_id`, `type`, `date_format`, `start_date`, `end_date`, `unavailable_date`, `created_at`, `created_by`, `updated_at`, `updated_by`)
|
||||
|
||||
@@ -5,11 +5,11 @@ DELETE FROM s2_canvas;
|
||||
|
||||
-- sample user
|
||||
-- The default value for the password is 123456
|
||||
insert into s2_user (id, "name", password, salt, display_name, email, is_admin) values (1, 'admin','c3VwZXJzb25pY0BiaWNvbdktJJYWw6A3rEmBUPzbn/6DNeYnD+y3mAwDKEMS3KVT','jGl25bVBBBW96Qi9Te4V3w==','admin','admin@xx.com', 1);
|
||||
insert into s2_user (id, "name", password, salt, display_name, email) values (2, 'jack','c3VwZXJzb25pY0BiaWNvbWxGalmwa0h/trkh/3CWOYMDiku0Op1VmOfESIKmN0HG','MWERWefm/3hD6kYndF6JIg==','jack','jack@xx.com');
|
||||
insert into s2_user (id, "name", password, salt, display_name, email) values (3, 'tom','c3VwZXJzb25pY0BiaWNvbVWv0CZ6HzeX8GRUpw0C8NSaQ+0hE/dAcmzRpCFwAqxK','4WCPdcXXgT89QDHLML+3hg==','tom','tom@xx.com');
|
||||
insert into s2_user (id, "name", password, salt, display_name, email, is_admin) values (4, 'lucy','c3VwZXJzb25pY0BiaWNvbc7Ychfu99lPL7rLmCkf/vgF4RASa4Z++Mxo1qlDCpci','3Jnpqob6uDoGLP9eCAg5Fw==','lucy','lucy@xx.com', 1);
|
||||
insert into s2_user (id, "name", password, salt, display_name, email) values (5, 'alice','c3VwZXJzb25pY0BiaWNvbe9Z4F2/DVIfAJoN1HwUTuH1KgVuiusvfh7KkWYQSNHk','K9gGyX8OAK8aH8Myj6djqQ==','alice','alice@xx.com');
|
||||
insert into s2_user ("name", password, salt, display_name, email, is_admin) values ('admin','c3VwZXJzb25pY0BiaWNvbdktJJYWw6A3rEmBUPzbn/6DNeYnD+y3mAwDKEMS3KVT','jGl25bVBBBW96Qi9Te4V3w==','admin','admin@xx.com', 1);
|
||||
insert into s2_user ("name", password, salt, display_name, email) values ('jack','c3VwZXJzb25pY0BiaWNvbWxGalmwa0h/trkh/3CWOYMDiku0Op1VmOfESIKmN0HG','MWERWefm/3hD6kYndF6JIg==','jack','jack@xx.com');
|
||||
insert into s2_user ("name", password, salt, display_name, email) values ('tom','c3VwZXJzb25pY0BiaWNvbVWv0CZ6HzeX8GRUpw0C8NSaQ+0hE/dAcmzRpCFwAqxK','4WCPdcXXgT89QDHLML+3hg==','tom','tom@xx.com');
|
||||
insert into s2_user ("name", password, salt, display_name, email, is_admin) values ('lucy','c3VwZXJzb25pY0BiaWNvbc7Ychfu99lPL7rLmCkf/vgF4RASa4Z++Mxo1qlDCpci','3Jnpqob6uDoGLP9eCAg5Fw==','lucy','lucy@xx.com', 1);
|
||||
insert into s2_user ("name", password, salt, display_name, email) values ('alice','c3VwZXJzb25pY0BiaWNvbe9Z4F2/DVIfAJoN1HwUTuH1KgVuiusvfh7KkWYQSNHk','K9gGyX8OAK8aH8Myj6djqQ==','alice','alice@xx.com');
|
||||
|
||||
|
||||
INSERT INTO s2_available_date_info (item_id, type, date_format, start_date, end_date, unavailable_date, created_at, created_by, updated_at, updated_by)
|
||||
|
||||
@@ -41,3 +41,5 @@ s2:
|
||||
threshold: 0.5
|
||||
min:
|
||||
threshold: 0.3
|
||||
embedding:
|
||||
use-llm-enhance: true
|
||||
|
||||
5
pom.xml
5
pom.xml
@@ -214,11 +214,6 @@
|
||||
<artifactId>mockito-inline</artifactId>
|
||||
<version>${mockito-inline.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springdoc</groupId>
|
||||
<artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
|
||||
<version>2.1.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.amazonaws</groupId>
|
||||
<artifactId>aws-java-sdk</artifactId>
|
||||
|
||||
@@ -56,9 +56,7 @@ const BarChart: React.FC<Props> = ({
|
||||
} else {
|
||||
instanceObj = instanceRef.current;
|
||||
}
|
||||
const data = (queryResults || []).sort(
|
||||
(a: any, b: any) => b[metricColumnName] - a[metricColumnName]
|
||||
);
|
||||
const data = (queryResults || []);
|
||||
const xData = data.map(item =>
|
||||
item[categoryColumnName] !== undefined ? item[categoryColumnName] : '未知'
|
||||
);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Form, Input, Spin, Select, message } from 'antd';
|
||||
import type { FormInstance } from 'antd/lib/form';
|
||||
import { getDbNames, getTables, getDimensionList } from '../../service';
|
||||
import {getDbNames, getTables, getDimensionList, getCatalogs} from '../../service';
|
||||
import { ISemantic } from '../../data';
|
||||
import FormItemTitle from '@/components/FormHelper/FormItemTitle';
|
||||
|
||||
@@ -20,13 +20,16 @@ const ModelBasicForm: React.FC<Props> = ({
|
||||
isEdit,
|
||||
modelItem,
|
||||
databaseConfigList,
|
||||
form,
|
||||
mode = 'normal',
|
||||
}) => {
|
||||
const [currentDbLinkConfigId, setCurrentDbLinkConfigId] = useState<number>();
|
||||
const [catalogList, setCatalogList] = useState<string[]>([]);
|
||||
const [dbNameList, setDbNameList] = useState<string[]>([]);
|
||||
const [tableNameList, setTableNameList] = useState<any[]>([]);
|
||||
const [loading, setLoading] = useState<boolean>(false);
|
||||
const [dimensionOptions, setDimensionOptions] = useState<{ label: string; value: number }[]>([]);
|
||||
const [catalogSelectOpen, setCatalogSelectOpen] = useState<boolean>(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (modelItem?.id) {
|
||||
@@ -50,9 +53,49 @@ const ModelBasicForm: React.FC<Props> = ({
|
||||
}
|
||||
};
|
||||
|
||||
const queryDbNameList = async (databaseId: number) => {
|
||||
const onDatabaseSelect = (databaseId: number, type: string) => {
|
||||
setLoading(true);
|
||||
const { code, data, msg } = await getDbNames(databaseId);
|
||||
if (type === 'STARROCKS') {
|
||||
queryCatalogList(databaseId)
|
||||
setCatalogSelectOpen(true);
|
||||
setDbNameList([]);
|
||||
} else {
|
||||
queryDbNameList(databaseId, "");
|
||||
setCatalogSelectOpen(false);
|
||||
setCatalogList([]);
|
||||
}
|
||||
form.setFieldsValue({
|
||||
catalog: undefined,
|
||||
dbName: undefined,
|
||||
tableName: undefined,
|
||||
})
|
||||
};
|
||||
|
||||
const queryCatalogList = async (databaseId: number) => {
|
||||
setLoading(true);
|
||||
const { code, data, msg } = await getCatalogs(databaseId);
|
||||
setLoading(false)
|
||||
if (code === 200) {
|
||||
const list = data || [];
|
||||
setCatalogList(list);
|
||||
} else {
|
||||
message.error(msg);
|
||||
}
|
||||
}
|
||||
|
||||
const onCatalogSelect = (catalog: string) => {
|
||||
if (currentDbLinkConfigId) {
|
||||
queryDbNameList(currentDbLinkConfigId, catalog);
|
||||
}
|
||||
form.setFieldsValue({
|
||||
dbName: undefined,
|
||||
tableName: undefined,
|
||||
})
|
||||
}
|
||||
|
||||
const queryDbNameList = async (databaseId: number, catalog: string) => {
|
||||
setLoading(true);
|
||||
const { code, data, msg } = await getDbNames(databaseId, catalog);
|
||||
setLoading(false);
|
||||
if (code === 200) {
|
||||
const list = data || [];
|
||||
@@ -61,6 +104,7 @@ const ModelBasicForm: React.FC<Props> = ({
|
||||
message.error(msg);
|
||||
}
|
||||
};
|
||||
|
||||
const queryTableNameList = async (databaseName: string) => {
|
||||
if (!currentDbLinkConfigId) {
|
||||
return;
|
||||
@@ -89,18 +133,37 @@ const ModelBasicForm: React.FC<Props> = ({
|
||||
showSearch
|
||||
placeholder="请选择数据库连接"
|
||||
disabled={isEdit}
|
||||
onChange={(dbLinkConfigId: number) => {
|
||||
queryDbNameList(dbLinkConfigId);
|
||||
onSelect={(dbLinkConfigId: number, option) => {
|
||||
onDatabaseSelect(dbLinkConfigId, option.type);
|
||||
setCurrentDbLinkConfigId(dbLinkConfigId);
|
||||
}}
|
||||
>
|
||||
{databaseConfigList.map((item) => (
|
||||
<Select.Option key={item.id} value={item.id} disabled={!item.hasUsePermission}>
|
||||
<Select.Option key={item.id} value={item.id} disabled={!item.hasUsePermission} type={item.type}>
|
||||
{item.name}
|
||||
</Select.Option>
|
||||
))}
|
||||
</Select>
|
||||
</FormItem>
|
||||
<FormItem
|
||||
name="catalog"
|
||||
label="Catalog"
|
||||
rules={[{ required: true, message: '请选择Catalog' }]}
|
||||
hidden={!catalogSelectOpen}
|
||||
>
|
||||
<Select
|
||||
showSearch
|
||||
placeholder="请选择Catalog"
|
||||
disabled={isEdit}
|
||||
onSelect={onCatalogSelect}
|
||||
>
|
||||
{catalogList.map((item) => (
|
||||
<Select.Option key={item} value={item}>
|
||||
{item}
|
||||
</Select.Option>
|
||||
))}
|
||||
</Select>
|
||||
</FormItem>
|
||||
<FormItem
|
||||
name="dbName"
|
||||
label="数据库名"
|
||||
@@ -110,8 +173,11 @@ const ModelBasicForm: React.FC<Props> = ({
|
||||
showSearch
|
||||
placeholder="请先选择一个数据库连接"
|
||||
disabled={isEdit}
|
||||
onChange={(dbName: string) => {
|
||||
onSelect={(dbName: string) => {
|
||||
queryTableNameList(dbName);
|
||||
form.setFieldsValue({
|
||||
tableName: undefined,
|
||||
})
|
||||
}}
|
||||
>
|
||||
{dbNameList.map((item) => (
|
||||
|
||||
@@ -379,11 +379,21 @@ export async function listColumnsBySql(data: { databaseId: number; sql: string }
|
||||
});
|
||||
}
|
||||
|
||||
export function getDbNames(dbId: number): Promise<any> {
|
||||
export function getCatalogs(dbId: number): Promise<any> {
|
||||
return request(`${process.env.API_BASE_URL}database/getCatalogs`, {
|
||||
method: 'GET',
|
||||
params: {
|
||||
id: dbId,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function getDbNames(dbId: number, catalog: string): Promise<any> {
|
||||
return request(`${process.env.API_BASE_URL}database/getDbNames`, {
|
||||
method: 'GET',
|
||||
params: {
|
||||
id: dbId,
|
||||
catalog: catalog,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user