mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(headless)Remove unnecessary TranslateSqlReq, use SemanticQueryReq instead.
This commit is contained in:
@@ -44,7 +44,7 @@ public class SqlExecutor implements ChatExecutor {
|
||||
.agentId(chatExecuteContext.getAgentId())
|
||||
.status(MemoryStatus.PENDING)
|
||||
.question(chatExecuteContext.getQueryText())
|
||||
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL())
|
||||
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getParsedS2SQL())
|
||||
.dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo()))
|
||||
.createdBy(chatExecuteContext.getUser().getName())
|
||||
.updatedBy(chatExecuteContext.getUser().getName())
|
||||
@@ -64,12 +64,12 @@ public class SqlExecutor implements ChatExecutor {
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(chatExecuteContext.getChatId());
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
if (Objects.isNull(parseInfo.getSqlInfo())
|
||||
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
QuerySqlReq sqlReq = QuerySqlReq.builder()
|
||||
.sql(parseInfo.getSqlInfo().getCorrectS2SQL())
|
||||
.sql(parseInfo.getSqlInfo().getCorrectedS2SQL())
|
||||
.build();
|
||||
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
|
||||
sqlReq.setDataSetId(parseInfo.getDataSetId());
|
||||
|
||||
@@ -161,7 +161,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
|
||||
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
|
||||
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL();
|
||||
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL();
|
||||
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
|
||||
.curtQuestion(currentMapResult.getQueryText())
|
||||
.histQuestion(lastParseResult.getQueryText())
|
||||
|
||||
@@ -5,8 +5,13 @@ import lombok.Data;
|
||||
@Data
|
||||
public class SqlInfo {
|
||||
|
||||
private String s2SQL;
|
||||
private String correctS2SQL;
|
||||
// S2SQL generated by semantic parsers
|
||||
private String parsedS2SQL;
|
||||
|
||||
// S2SQL corrected by semantic correctors
|
||||
private String correctedS2SQL;
|
||||
|
||||
// SQL to be executed finally
|
||||
private String querySQL;
|
||||
private String sourceId;
|
||||
|
||||
}
|
||||
|
||||
@@ -164,7 +164,7 @@ public class QueryStructReq extends SemanticQueryReq {
|
||||
result.setDataSetId(this.getDataSetId());
|
||||
result.setModelIds(this.getModelIdSet());
|
||||
result.setParams(new ArrayList<>());
|
||||
result.getSqlInfo().setCorrectS2SQL(sql);
|
||||
result.getSqlInfo().setCorrectedS2SQL(sql);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class TranslateSqlReq<T> {
|
||||
|
||||
private QueryMethod queryTypeEnum;
|
||||
|
||||
private T queryReq;
|
||||
}
|
||||
@@ -13,10 +13,12 @@ import java.io.Serializable;
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class TranslateResp implements Serializable {
|
||||
public class SemanticTranslateResp implements Serializable {
|
||||
|
||||
private String querySQL;
|
||||
|
||||
private String sql;
|
||||
private boolean isOk;
|
||||
|
||||
private String errMsg;
|
||||
|
||||
}
|
||||
@@ -22,7 +22,7 @@ public class AggCorrector extends BaseSemanticCorrector {
|
||||
|
||||
private void addAggregate(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
try {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
return;
|
||||
}
|
||||
doCorrect(chatQueryContext, semanticParseInfo);
|
||||
@@ -74,7 +74,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
//add aggregate to all metric
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
|
||||
List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId);
|
||||
|
||||
@@ -98,7 +98,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
return;
|
||||
}
|
||||
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMetricElements(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
|
||||
@@ -34,8 +34,8 @@ public class GrammarCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
//add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
// check has distinct
|
||||
if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
|
||||
@@ -68,7 +68,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
//add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
//add alias field name
|
||||
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
||||
@@ -83,7 +83,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -49,19 +49,19 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
}
|
||||
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(), metrics);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(havingSql);
|
||||
}
|
||||
|
||||
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
|
||||
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
||||
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -50,21 +50,21 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectedS2SQL(), aggregateEnum);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectedS2SQL());
|
||||
sqlInfo.setCorrectedS2SQL(replaceAlias);
|
||||
}
|
||||
|
||||
private void correctFieldName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||
@@ -79,8 +79,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
|
||||
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectedS2SQL(), fieldValueToFieldNames);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
private List<LLMReq.ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
|
||||
@@ -111,14 +111,14 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
)));
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectedS2SQL(), filedNameToValueMap, false);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereExpressionList)) {
|
||||
return;
|
||||
@@ -143,7 +143,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
|
||||
|
||||
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||
@@ -44,7 +44,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL);
|
||||
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(querySql);
|
||||
}
|
||||
|
||||
protected String addFieldsToSelect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo,
|
||||
@@ -65,7 +65,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
needAddFields.removeAll(selectFields);
|
||||
String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
|
||||
return addFieldsToSelectSql;
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
//decide whether remove date field from where
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String correctorDate = environment.getProperty("s2.corrector.date");
|
||||
@@ -55,12 +55,12 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
removeFieldNames.add(TimeDimensionEnum.WEEK.getChName());
|
||||
removeFieldNames.add(TimeDimensionEnum.MONTH.getChName());
|
||||
correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||
|
||||
//decide whether add date field to where
|
||||
@@ -88,11 +88,11 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
}
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
|
||||
if (Objects.isNull(dateBoundInfo)) {
|
||||
return;
|
||||
@@ -108,14 +108,14 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
|
||||
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||
log.info("add queryFilter to correctS2SQL :{}", queryFilter);
|
||||
@@ -50,7 +50,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,9 +71,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(),
|
||||
aliasAndBizNameToTechName);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
||||
|
||||
@@ -46,14 +46,14 @@ public class QueryTypeParser implements SemanticParser {
|
||||
private QueryType getQueryType(ChatQueryContext chatQueryContext, SemanticQuery semanticQuery) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) {
|
||||
return QueryType.DETAIL;
|
||||
}
|
||||
//1. entity queryType
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL());
|
||||
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
|
||||
@@ -63,7 +63,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
return QueryType.ID;
|
||||
}
|
||||
}
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
|
||||
selectFields.addAll(whereFields);
|
||||
List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields);
|
||||
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
|
||||
@@ -91,7 +91,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
|
||||
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
|
||||
@@ -38,7 +38,7 @@ public class LLMResponseService {
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setS2SQL(s2SQL);
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
@@ -83,8 +83,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
convertBizNameToName(semanticSchema, queryStructReq);
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -33,6 +33,6 @@ public class LLMSqlQuery extends LLMSemanticQuery {
|
||||
@Override
|
||||
public void initS2Sql(SemanticSchema semanticSchema, User user) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
sqlInfo.setCorrectS2SQL(sqlInfo.getS2SQL());
|
||||
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,8 +152,8 @@ public class QueryReqBuilder {
|
||||
|
||||
public static QuerySqlReq buildS2SQLReq(SqlInfo sqlInfo, Long dataSetId) {
|
||||
QuerySqlReq querySQLReq = new QuerySqlReq();
|
||||
if (Objects.nonNull(sqlInfo.getCorrectS2SQL())) {
|
||||
querySQLReq.setSql(sqlInfo.getCorrectS2SQL());
|
||||
if (Objects.nonNull(sqlInfo.getCorrectedS2SQL())) {
|
||||
querySQLReq.setSql(sqlInfo.getCorrectedS2SQL());
|
||||
}
|
||||
querySQLReq.setSqlInfo(sqlInfo);
|
||||
querySQLReq.setDataSetId(dataSetId);
|
||||
|
||||
@@ -30,14 +30,14 @@ class AggCorrectorTest {
|
||||
String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
|
||||
+ " datediff('day', 数据日期, '2024-06-04') <= 7"
|
||||
+ " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1";
|
||||
sqlInfo.setS2SQL(sql);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setParsedS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.correct(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'"
|
||||
+ " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户"
|
||||
+ " ORDER BY SUM(访问次数) DESC LIMIT 1",
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||
|
||||
@@ -65,8 +65,8 @@ class SchemaCorrectorTest {
|
||||
+ "and 商务组 = 'xxx' order by 播放量 desc limit 10";
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
sqlInfo.setS2SQL(sql);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setParsedS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
@@ -80,7 +80,7 @@ class SchemaCorrectorTest {
|
||||
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
|
||||
parseResult = objectMapper.readValue(json, ParseResult.class);
|
||||
|
||||
@@ -92,11 +92,11 @@ class SchemaCorrectorTest {
|
||||
parseResult.setLinkingValues(linkingValues);
|
||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
|
||||
semanticParseInfo.getSqlInfo().setS2SQL(sql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.getSqlInfo().setParsedS2SQL(sql);
|
||||
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -44,12 +44,12 @@ class SelectCorrectorTest {
|
||||
semanticParseInfo.setQueryType(QueryType.DETAIL);
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
String sql = "SELECT * FROM 艺人库 WHERE 艺人名='周杰伦'";
|
||||
sqlInfo.setS2SQL(sql);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setParsedS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.correct(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||
|
||||
@@ -19,84 +19,84 @@ class TimeCorrectorTest {
|
||||
//1.数据日期 <=
|
||||
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
|
||||
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//2.数据日期 <
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
|
||||
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//3.数据日期 >=
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//4.数据日期 >
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//5.no 数据日期
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//6. 数据日期-月 <=
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
|
||||
+ "AND 数据日期_月 >= '2024-01' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//7. 数据日期-月 >
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//8. no where
|
||||
sql = "SELECT COUNT(1) FROM 数据库";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
|
||||
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectedS2SQL());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ class WhereCorrectorTest {
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
|
||||
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||
@@ -54,7 +54,7 @@ class WhereCorrectorTest {
|
||||
WhereCorrector whereCorrector = new WhereCorrector();
|
||||
whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo);
|
||||
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE "
|
||||
+ "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND "
|
||||
|
||||
@@ -19,7 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
|
||||
@@ -73,9 +72,6 @@ public class S2DataPermissionAspect {
|
||||
SemanticQueryReq queryReq = null;
|
||||
if (objects[0] instanceof SemanticQueryReq) {
|
||||
queryReq = (SemanticQueryReq) objects[0];
|
||||
} else if (objects[0] instanceof TranslateSqlReq) {
|
||||
queryReq = (SemanticQueryReq) ((TranslateSqlReq<?>) objects[0]).getQueryReq();
|
||||
needQueryData = false;
|
||||
}
|
||||
if (queryReq == null) {
|
||||
throw new InvalidArgumentException("queryReq is not Invalid");
|
||||
|
||||
@@ -4,10 +4,9 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
|
||||
@@ -20,12 +19,12 @@ public interface SemanticLayerService {
|
||||
|
||||
DataSetSchema getDataSetSchema(Long id);
|
||||
|
||||
SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception;
|
||||
|
||||
SemanticQueryResp queryByReq(SemanticQueryReq queryReq, User user) throws Exception;
|
||||
|
||||
SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
|
||||
|
||||
<T> TranslateResp translate(TranslateSqlReq<T> translateSqlReq, User user) throws Exception;
|
||||
|
||||
EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user);
|
||||
|
||||
List<ItemResp> getDomainDataSetTree();
|
||||
|
||||
@@ -26,11 +26,9 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.CostType;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
@@ -41,7 +39,7 @@ import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -247,8 +245,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
List<String> fields = new ArrayList<>();
|
||||
if (Objects.nonNull(parseInfo.getSqlInfo())
|
||||
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
fields = SqlSelectHelper.getAllFields(correctorSql);
|
||||
}
|
||||
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
|
||||
@@ -260,13 +258,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
|
||||
log.info("llm begin revise filters!");
|
||||
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
|
||||
.queryTypeEnum(QueryMethod.SQL).build();
|
||||
TranslateResp explain = semanticLayerService.translate(translateSqlReq, user);
|
||||
parseInfo.getSqlInfo().setQuerySQL(explain.getSql());
|
||||
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
|
||||
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||
} else {
|
||||
log.info("rule begin replace metrics and revise filters!");
|
||||
//remove unvalid filters
|
||||
@@ -303,7 +299,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
|
||||
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
log.info("correctorSql before replacing:{}", correctorSql);
|
||||
// get where filter and having filter
|
||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
|
||||
@@ -334,7 +330,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
|
||||
List<String> oriMetrics = parseInfo.getMetrics().stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
log.info("before replaceMetrics:{}", correctorSql);
|
||||
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
|
||||
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
|
||||
@@ -343,7 +339,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
|
||||
}
|
||||
log.info("after replaceMetrics:{}", correctorSql);
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||
}
|
||||
|
||||
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
|
||||
@@ -598,7 +594,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
public void correct(QuerySqlReq querySqlReq, User user) {
|
||||
SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user);
|
||||
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -613,8 +609,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
queryCtx.setSemanticSchema(semanticSchema);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
sqlInfo.setCorrectS2SQL(querySqlReq.getSql());
|
||||
sqlInfo.setS2SQL(querySqlReq.getSql());
|
||||
sqlInfo.setCorrectedS2SQL(querySqlReq.getSql());
|
||||
sqlInfo.setParsedS2SQL(querySqlReq.getSql());
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
semanticParseInfo.setQueryType(QueryType.DETAIL);
|
||||
|
||||
@@ -630,7 +626,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
corrector.correct(queryCtx, semanticParseInfo);
|
||||
}
|
||||
});
|
||||
log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectS2SQL());
|
||||
log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectedS2SQL());
|
||||
return semanticParseInfo;
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
@@ -28,7 +27,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
@@ -228,12 +227,11 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
||||
|
||||
@S2DataPermission
|
||||
@Override
|
||||
public <T> TranslateResp translate(TranslateSqlReq<T> translateSqlReq, User user) throws Exception {
|
||||
T queryReq = translateSqlReq.getQueryReq();
|
||||
QueryStatement queryStatement = buildQueryStatement((SemanticQueryReq) queryReq, user);
|
||||
public SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception {
|
||||
QueryStatement queryStatement = buildQueryStatement(queryReq, user);
|
||||
semanticTranslator.translate(queryStatement);
|
||||
return TranslateResp.builder()
|
||||
.sql(queryStatement.getSql())
|
||||
return SemanticTranslateResp.builder()
|
||||
.querySQL(queryStatement.getSql())
|
||||
.isOk(queryStatement.isOk())
|
||||
.errMsg(queryStatement.getErrMsg())
|
||||
.build();
|
||||
|
||||
@@ -52,12 +52,12 @@ public class ParseInfoProcessor implements ResultProcessor {
|
||||
|
||||
public void updateParseInfo(SemanticParseInfo parseInfo) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
if (StringUtils.isBlank(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
// if S2SQL equals correctS2SQL, then not update the parseInfo.
|
||||
if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
|
||||
if (correctS2SQL.equals(sqlInfo.getParsedS2SQL())) {
|
||||
return;
|
||||
}
|
||||
List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(correctS2SQL);
|
||||
@@ -87,15 +87,15 @@ public class ParseInfoProcessor implements ResultProcessor {
|
||||
if (Objects.isNull(semanticSchema)) {
|
||||
return;
|
||||
}
|
||||
List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
|
||||
List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectedS2SQL()));
|
||||
Set<SchemaElement> metrics = getElements(dataSetId, allFields, semanticSchema.getMetrics());
|
||||
parseInfo.setMetrics(metrics);
|
||||
if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
|
||||
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectedS2SQL());
|
||||
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
||||
parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions()));
|
||||
} else if (QueryType.DETAIL.equals(parseInfo.getQueryType())) {
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectedS2SQL());
|
||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||
parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions()));
|
||||
}
|
||||
|
||||
@@ -4,11 +4,9 @@ import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
|
||||
@@ -123,15 +121,13 @@ public class ChatWorkflowEngine {
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
|
||||
.queryTypeEnum(QueryMethod.SQL).build();
|
||||
TranslateResp explain = queryService.translate(translateSqlReq, chatQueryContext.getUser());
|
||||
parseInfo.getSqlInfo().setQuerySQL(explain.getSql());
|
||||
SemanticTranslateResp explain = queryService.translate(semanticQueryReq, chatQueryContext.getUser());
|
||||
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||
|
||||
keyPipelineLog.info("SqlInfoProcessor results:\n"
|
||||
+ "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}",
|
||||
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getS2SQL()),
|
||||
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getCorrectS2SQL()),
|
||||
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getParsedS2SQL()),
|
||||
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getCorrectedS2SQL()),
|
||||
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getQuerySQL()));
|
||||
} catch (Exception e) {
|
||||
log.warn("get sql info failed:{}", parseInfo, e);
|
||||
|
||||
@@ -2,11 +2,8 @@ package com.tencent.supersonic.headless;
|
||||
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -21,29 +18,22 @@ public class TranslateTest extends BaseTest {
|
||||
@Test
|
||||
public void testSqlExplain() throws Exception {
|
||||
String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
|
||||
TranslateSqlReq<QuerySqlReq> translateSqlReq = TranslateSqlReq.<QuerySqlReq>builder()
|
||||
.queryTypeEnum(QueryMethod.SQL)
|
||||
.queryReq(QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()))
|
||||
.build();
|
||||
TranslateResp explain = semanticLayerService.translate(translateSqlReq, User.getFakeUser());
|
||||
SemanticTranslateResp explain = semanticLayerService.translate(QueryReqBuilder.buildS2SQLReq(sql,
|
||||
DataUtils.getMetricAgentView()), User.getFakeUser());
|
||||
assertNotNull(explain);
|
||||
assertNotNull(explain.getSql());
|
||||
assertTrue(explain.getSql().contains("department"));
|
||||
assertTrue(explain.getSql().contains("pv"));
|
||||
assertNotNull(explain.getQuerySQL());
|
||||
assertTrue(explain.getQuerySQL().contains("department"));
|
||||
assertTrue(explain.getQuerySQL().contains("pv"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStructExplain() throws Exception {
|
||||
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
|
||||
TranslateSqlReq<QueryStructReq> translateSqlReq = TranslateSqlReq.<QueryStructReq>builder()
|
||||
.queryTypeEnum(QueryMethod.STRUCT)
|
||||
.queryReq(queryStructReq)
|
||||
.build();
|
||||
TranslateResp explain = semanticLayerService.translate(translateSqlReq, User.getFakeUser());
|
||||
SemanticTranslateResp explain = semanticLayerService.translate(queryStructReq, User.getFakeUser());
|
||||
assertNotNull(explain);
|
||||
assertNotNull(explain.getSql());
|
||||
assertTrue(explain.getSql().contains("department"));
|
||||
assertTrue(explain.getSql().contains("pv"));
|
||||
assertNotNull(explain.getQuerySQL());
|
||||
assertTrue(explain.getQuerySQL().contains("department"));
|
||||
assertTrue(explain.getQuerySQL().contains("pv"));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user