Merge branch 'master' into master

This commit is contained in:
Jun Zhang
2025-07-09 17:20:40 +08:00
committed by GitHub
35 changed files with 412 additions and 106 deletions

View File

@@ -18,6 +18,7 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_DETAIL_LIMIT;
import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT;
@@ -65,12 +66,23 @@ public class SemanticParseInfo implements Serializable {
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
if (difference == 0) {
if (Math.abs(difference) < 0.0005) { // 看完全匹配的个数,实践证明,可以用户输入规范后,该逻辑具有优势
if (!o1.getDataSetId().equals(o2.getDataSetId())) {
List<SchemaElementMatch> elementMatches1 = o1.getElementMatches().stream()
.filter(e -> e.getSimilarity() == 1).collect(Collectors.toList());
List<SchemaElementMatch> elementMatches2 = o2.getElementMatches().stream()
.filter(e -> e.getSimilarity() == 1).collect(Collectors.toList());
if (elementMatches1.size() > elementMatches2.size()) {
return -1;
} else if (elementMatches1.size() < elementMatches2.size()) {
return 1;
}
}
difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity();
if (difference == 0) {
if (Math.abs(difference) < 0.0005) {
difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity();
}
if (difference == 0) {
if (Math.abs(difference) < 0.0005) {
difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt();
}
}

View File

@@ -16,4 +16,7 @@ public class SqlInfo implements Serializable {
// SQL to be executed finally
private String querySQL;
// Physical SQL corrected by LLM for performance optimization
private String correctedQuerySQL;
}

View File

@@ -8,5 +8,6 @@ public enum ChatWorkflowState {
VALIDATING,
SQL_CORRECTING,
PROCESSING,
PHYSICAL_SQL_CORRECTING,
FINISHED
}

View File

@@ -0,0 +1,98 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.service.AiServices;
import lombok.Data;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* 物理SQL修正器 - 使用LLM优化物理SQL性能
*/
@Slf4j
public class LLMPhysicalSqlCorrector extends BaseSemanticCorrector {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
public static final String APP_KEY = "PHYSICAL_SQL_CORRECTOR";
private static final String INSTRUCTION = ""
+ "#Role: You are a senior database performance optimization expert experienced in SQL tuning."
+ "\n\n#Task: You will be provided with a user question and the corresponding physical SQL query,"
+ " please analyze and optimize this SQL to improve query performance." + "\n\n#Rules:"
+ "\n1. DO NOT add or introduce any new fields, columns, or aliases that are not in the original SQL."
+ "\n2. Push WHERE conditions into JOIN ON clauses when possible to reduce intermediate result sets."
+ "\n3. Optimize JOIN order by placing smaller tables or tables with selective conditions first."
+ "\n4. For date range conditions, ensure they are applied as early as possible in the query execution."
+ "\n5. Remove or comment out database-specific index hints (like USE INDEX) that may cause syntax errors."
+ "\n6. ONLY modify the structure and order of existing elements, do not change field names or add new ones."
+ "\n7. Ensure the optimized SQL is syntactically correct and logically equivalent to the original."
+ "\n\n#Question: {{question}}" + "\n\n#OriginalSQL: {{sql}}";
public LLMPhysicalSqlCorrector() {
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("物理SQL修正")
.appModule(AppModule.CHAT).description("通过大模型对物理SQL做性能优化").enable(false).build());
}
@Data
@ToString
static class PhysicalSql {
@Description("either positive or negative")
private String opinion;
@Description("optimized sql if negative")
private String sql;
}
interface PhysicalSqlExtractor {
PhysicalSql generatePhysicalSql(String text);
}
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
ChatApp chatApp = chatQueryContext.getRequest().getChatAppConfig().get(APP_KEY);
if (!chatQueryContext.getRequest().getText2SQLType().enableLLM() || Objects.isNull(chatApp)
|| !chatApp.isEnable()) {
return;
}
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig());
PhysicalSqlExtractor extractor =
AiServices.create(PhysicalSqlExtractor.class, chatLanguageModel);
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
semanticParseInfo, chatApp.getPrompt());
PhysicalSql physicalSql =
extractor.generatePhysicalSql(prompt.toUserMessage().singleText());
keyPipelineLog.info("LLMPhysicalSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(),
physicalSql);
if ("NEGATIVE".equalsIgnoreCase(physicalSql.getOpinion())
&& StringUtils.isNotBlank(physicalSql.getSql())) {
semanticParseInfo.getSqlInfo().setCorrectedQuerySQL(physicalSql.getSql());
}
}
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
String promptTemplate) {
Map<String, Object> variable = new HashMap<>();
variable.put("question", queryText);
variable.put("sql", semanticParseInfo.getSqlInfo().getQuerySQL());
return PromptTemplate.from(promptTemplate).apply(variable);
}
}

View File

@@ -14,10 +14,8 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.*;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*;
@@ -51,13 +49,33 @@ public class PromptHelper {
// use random collection of exemplars for each self-consistency inference
for (int i = 0; i < selfConsistencyNumber; i++) {
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
// only shuffle the exemplars from config
List<Text2SQLExemplar> subList =
shuffledList.subList(llmReq.getDynamicExemplars().size(), shuffledList.size());
Collections.shuffle(subList);
results.add(shuffledList.subList(0, Math.min(shuffledList.size(), fewShotNumber)));
List<Text2SQLExemplar> same = shuffledList.stream() // 相似度极高的话,先找出来
.filter(e -> e.getSimilarity() > 0.989).collect(Collectors.toList());
List<Text2SQLExemplar> noSame = shuffledList.stream()
.filter(e -> e.getSimilarity() <= 0.989).collect(Collectors.toList());
if ((noSame.size() - same.size()) > fewShotNumber) {// 去除部分最低分
noSame.sort(Comparator.comparingDouble(Text2SQLExemplar::getSimilarity));
noSame = noSame.subList((noSame.size() - fewShotNumber) / 2, noSame.size());
}
Text2SQLExemplar mostSimilar = noSame.get(noSame.size() - 1);
Collections.shuffle(noSame);
List<Text2SQLExemplar> ts;
if (same.size() > 0) {// 一样的话,必须作为提示语
ts = new ArrayList<>();
int needSize = Math.min(noSame.size() + same.size(), fewShotNumber);
if (needSize > same.size()) {
ts.addAll(noSame.subList(0, needSize - same.size()));
}
ts.addAll(same);
} else { // 至少要一个最像的
ts = noSame.subList(0, Math.min(noSame.size(), fewShotNumber));
if (!ts.contains(mostSimilar)) {
ts.remove(ts.size() - 1);
ts.add(mostSimilar);
}
}
results.add(ts);
}
return results;
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.core.pojo;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
import lombok.Data;
@@ -24,6 +25,7 @@ public class QueryStatement {
private SemanticSchemaResp semanticSchema;
private Integer limit = 1000;
private Boolean isTranslated = false;
private User user;
public boolean isOk() {
return StringUtils.isBlank(errMsg) && StringUtils.isNotBlank(sql);

View File

@@ -102,7 +102,7 @@ public class DimValueAspect {
continue;
}
for (DimensionResp dimension : dimensions) {
if (!expression.getFieldName().equals(dimension.getName())
if (!expression.getFieldName().equals(dimension.getBizName())
|| CollectionUtils.isEmpty(dimension.getDimValueMaps())) {
continue;
}
@@ -124,6 +124,7 @@ public class DimValueAspect {
sql = SqlReplaceHelper.replaceValue(sql, filedNameToValueMap);
log.debug("correctorSql after replacing:{}", sql);
querySqlReq.setSql(sql);
querySqlReq.getSqlInfo().setQuerySQL(sql);
Map<String, Map<String, String>> techNameToBizName = getTechNameToBizName(dimensions);
SemanticQueryResp queryResultWithColumns = (SemanticQueryResp) joinPoint.proceed();

View File

@@ -296,6 +296,9 @@ public class S2SemanticLayerService implements SemanticLayerService {
queryStatement.setSql(semanticQueryReq.getSqlInfo().getQuerySQL());
queryStatement.setIsTranslated(true);
}
if (queryStatement != null) {
queryStatement.setUser(user);
}
return queryStatement;
}

View File

@@ -83,13 +83,10 @@ public class DimensionRepositoryImpl implements DimensionRepository {
}
if (StringUtils.isNotBlank(dimensionFilter.getKey())) {
String key = dimensionFilter.getKey();
queryWrapper.lambda()
.and(wrapper -> wrapper
.like(DimensionDO::getName, key).or()
.like(DimensionDO::getBizName, key).or().like(DimensionDO::getDescription, key)
.or().like(DimensionDO::getAlias, key).or()
.like(DimensionDO::getCreatedBy, key)
);
queryWrapper.and(qw->qw.lambda().like(DimensionDO::getName, key).or()
.like(DimensionDO::getBizName, key).or().like(DimensionDO::getDescription, key)
.or().like(DimensionDO::getAlias, key).or()
.like(DimensionDO::getCreatedBy, key));
}
return dimensionDOMapper.selectList(queryWrapper);

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.LLMPhysicalSqlCorrector;
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
@@ -76,6 +77,10 @@ public class ChatWorkflowEngine {
long start = System.currentTimeMillis();
performTranslating(queryCtx, parseResult);
parseResult.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
queryCtx.setChatWorkflowState(ChatWorkflowState.PHYSICAL_SQL_CORRECTING);
break;
case PHYSICAL_SQL_CORRECTING:
performPhysicalSqlCorrecting(queryCtx);
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
break;
default:
@@ -162,4 +167,26 @@ public class ChatWorkflowEngine {
parseResult.setErrorMsg(String.join("\n", errorMsg));
}
}
private void performPhysicalSqlCorrecting(ChatQueryContext queryCtx) {
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
for (SemanticCorrector corrector : semanticCorrectors) {
if (corrector instanceof LLMPhysicalSqlCorrector) {
corrector.correct(queryCtx, semanticQuery.getParseInfo());
// 如果物理SQL被修正了更新querySQL为修正后的版本
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
if (StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedQuerySQL())) {
parseInfo.getSqlInfo()
.setQuerySQL(parseInfo.getSqlInfo().getCorrectedQuerySQL());
log.info("Physical SQL corrected and updated querySQL: {}",
parseInfo.getSqlInfo().getQuerySQL());
}
break;
}
}
}
}
}
}

View File

@@ -19,9 +19,10 @@ public class DataSetSchemaBuilder {
public static DataSetSchema build(DataSetSchemaResp resp) {
DataSetSchema dataSetSchema = new DataSetSchema();
dataSetSchema.setQueryConfig(resp.getQueryConfig());
SchemaElement dataSet = SchemaElement.builder().dataSetId(resp.getId())
.dataSetName(resp.getName()).id(resp.getId()).name(resp.getName())
.bizName(resp.getBizName()).type(SchemaElementType.DATASET).build();
SchemaElement dataSet =
SchemaElement.builder().dataSetId(resp.getId()).dataSetName(resp.getName())
.id(resp.getId()).name(resp.getName()).description(resp.getDescription())
.bizName(resp.getBizName()).type(SchemaElementType.DATASET).build();
dataSetSchema.setDataSet(dataSet);
dataSetSchema.setDatabaseType(resp.getDatabaseType());
dataSetSchema.setDatabaseVersion(resp.getDatabaseVersion());

View File

@@ -138,7 +138,8 @@ public class DictUtils {
semanticQueryReq.setNeedAuth(false);
String bizName = dictItemResp.getBizName();
try {
SemanticQueryResp semanticQueryResp = queryService.queryByReq(semanticQueryReq, null);
SemanticQueryResp semanticQueryResp =
queryService.queryByReq(semanticQueryReq, User.getDefaultUser());
if (Objects.isNull(semanticQueryResp)
|| CollectionUtils.isEmpty(semanticQueryResp.getResultList())) {
return lines;
@@ -274,6 +275,9 @@ public class DictUtils {
private QuerySqlReq constructQuerySqlReq(DictItemResp dictItemResp) {
ModelResp model = modelService.getModel(dictItemResp.getModelId());
String tableStr = StringUtils.isNotBlank(model.getModelDetail().getTableQuery())
? model.getModelDetail().getTableQuery()
: "(" + model.getModelDetail().getSqlQuery() + ")";
String sqlPattern =
"select %s,count(1) from %s %s group by %s order by count(1) desc limit %d";
String dimBizName = dictItemResp.getBizName();
@@ -287,8 +291,7 @@ public class DictUtils {
limit = Integer.MAX_VALUE;
}
String sql =
String.format(sqlPattern, dimBizName, model.getBizName(), where, dimBizName, limit);
String sql = String.format(sqlPattern, dimBizName, tableStr, where, dimBizName, limit);
Set<Long> modelIds = new HashSet<>();
modelIds.add(dictItemResp.getModelId());
QuerySqlReq querySqlReq = new QuerySqlReq();

View File

@@ -109,7 +109,8 @@ public class QueryUtils {
column.setModelId(metric.getModelId());
}
// if column nameEn contains metric alias, use metric dataFormatType
if (column.getDataFormatType() == null && metric.getAlias() != null) {
if (column.getDataFormatType() == null
&& StringUtils.isNotEmpty(metric.getAlias())) {
for (String alias : metric.getAlias().split(",")) {
if (nameEn.contains(alias)) {
column.setDataFormatType(metric.getDataFormatType());