(improvement)(headless)Remove unnecessary TranslateSqlReq, use SemanticQueryReq instead.

This commit is contained in:
jerryjzhang
2024-07-09 10:48:48 +08:00
parent 7a376bd9a3
commit f0b4eb46cf
32 changed files with 138 additions and 176 deletions

View File

@@ -44,7 +44,7 @@ public class SqlExecutor implements ChatExecutor {
.agentId(chatExecuteContext.getAgentId())
.status(MemoryStatus.PENDING)
.question(chatExecuteContext.getQueryText())
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL())
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getParsedS2SQL())
.dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo()))
.createdBy(chatExecuteContext.getUser().getName())
.updatedBy(chatExecuteContext.getUser().getName())
@@ -64,12 +64,12 @@ public class SqlExecutor implements ChatExecutor {
ChatContext chatCtx = chatContextService.getOrCreateContext(chatExecuteContext.getChatId());
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
if (Objects.isNull(parseInfo.getSqlInfo())
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
return null;
}
QuerySqlReq sqlReq = QuerySqlReq.builder()
.sql(parseInfo.getSqlInfo().getCorrectS2SQL())
.sql(parseInfo.getSqlInfo().getCorrectedS2SQL())
.build();
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
sqlReq.setDataSetId(parseInfo.getDataSetId());

View File

@@ -161,7 +161,7 @@ public class NL2SQLParser implements ChatParser {
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL();
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL();
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
.curtQuestion(currentMapResult.getQueryText())
.histQuestion(lastParseResult.getQueryText())

View File

@@ -5,8 +5,13 @@ import lombok.Data;
@Data
public class SqlInfo {
private String s2SQL;
private String correctS2SQL;
// S2SQL generated by semantic parsers
private String parsedS2SQL;
// S2SQL corrected by semantic correctors
private String correctedS2SQL;
// SQL to be executed finally
private String querySQL;
private String sourceId;
}

View File

@@ -164,7 +164,7 @@ public class QueryStructReq extends SemanticQueryReq {
result.setDataSetId(this.getDataSetId());
result.setModelIds(this.getModelIdSet());
result.setParams(new ArrayList<>());
result.getSqlInfo().setCorrectS2SQL(sql);
result.getSqlInfo().setCorrectedS2SQL(sql);
return result;
}

View File

@@ -1,20 +0,0 @@
package com.tencent.supersonic.headless.api.pojo.request;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
@Data
@ToString
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class TranslateSqlReq<T> {
private QueryMethod queryTypeEnum;
private T queryReq;
}

View File

@@ -13,10 +13,12 @@ import java.io.Serializable;
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class TranslateResp implements Serializable {
public class SemanticTranslateResp implements Serializable {
private String querySQL;
private String sql;
private boolean isOk;
private String errMsg;
}

View File

@@ -22,7 +22,7 @@ public class AggCorrector extends BaseSemanticCorrector {
private void addAggregate(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
return;
}

View File

@@ -29,7 +29,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
try {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) {
return;
}
doCorrect(chatQueryContext, semanticParseInfo);
@@ -74,7 +74,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId);
@@ -98,7 +98,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return;
}
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(aggregateSql);
}
protected List<SchemaElement> getMetricElements(ChatQueryContext chatQueryContext, Long dataSetId) {

View File

@@ -34,8 +34,8 @@ public class GrammarCorrector extends BaseSemanticCorrector {
}
public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}

View File

@@ -35,7 +35,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
Long dataSetId = semanticParseInfo.getDataSetId();
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
// check has distinct
if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
@@ -68,7 +68,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
Long dataSetId = semanticParseInfo.getDataSetId();
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
//add alias field name
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
@@ -83,7 +83,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return true;
})
.collect(Collectors.toSet());
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
}
}

View File

@@ -49,19 +49,19 @@ public class HavingCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(metrics)) {
return;
}
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(), metrics);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(havingSql);
}
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
return;
}
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
if (!CollectionUtils.isEmpty(havingExpressionList)) {
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(replaceSql);
}
return;
}

View File

@@ -50,21 +50,21 @@ public class SchemaCorrector extends BaseSemanticCorrector {
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
sqlInfo.setCorrectS2SQL(sql);
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectedS2SQL(), aggregateEnum);
sqlInfo.setCorrectedS2SQL(sql);
}
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
sqlInfo.setCorrectS2SQL(replaceAlias);
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectedS2SQL());
sqlInfo.setCorrectedS2SQL(replaceAlias);
}
private void correctFieldName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
sqlInfo.setCorrectS2SQL(sql);
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap);
sqlInfo.setCorrectedS2SQL(sql);
}
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
@@ -79,8 +79,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
sqlInfo.setCorrectS2SQL(sql);
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectedS2SQL(), fieldValueToFieldNames);
sqlInfo.setCorrectedS2SQL(sql);
}
private List<LLMReq.ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
@@ -111,14 +111,14 @@ public class SchemaCorrector extends BaseSemanticCorrector {
)));
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
sqlInfo.setCorrectS2SQL(sql);
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectedS2SQL(), filedNameToValueMap, false);
sqlInfo.setCorrectedS2SQL(sql);
}
public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
if (CollectionUtils.isEmpty(whereExpressionList)) {
return;
@@ -143,7 +143,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
}
}

View File

@@ -33,7 +33,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
@@ -44,7 +44,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
}
correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL);
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(querySql);
}
protected String addFieldsToSelect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo,
@@ -65,7 +65,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
}
needAddFields.removeAll(selectFields);
String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
return addFieldsToSelectSql;
}

View File

@@ -45,7 +45,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
}
private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
//decide whether remove date field from where
Environment environment = ContextUtils.getBean(Environment.class);
String correctorDate = environment.getProperty("s2.corrector.date");
@@ -55,12 +55,12 @@ public class TimeCorrector extends BaseSemanticCorrector {
removeFieldNames.add(TimeDimensionEnum.WEEK.getChName());
removeFieldNames.add(TimeDimensionEnum.MONTH.getChName());
correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
//decide whether add date field to where
@@ -88,11 +88,11 @@ public class TimeCorrector extends BaseSemanticCorrector {
}
}
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
if (Objects.isNull(dateBoundInfo)) {
return;
@@ -108,14 +108,14 @@ public class TimeCorrector extends BaseSemanticCorrector {
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}

View File

@@ -39,7 +39,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to correctS2SQL :{}", queryFilter);
@@ -50,7 +50,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
log.error("parseCondExpression", e);
}
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}
@@ -71,9 +71,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(),
aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {

View File

@@ -46,14 +46,14 @@ public class QueryTypeParser implements SemanticParser {
private QueryType getQueryType(ChatQueryContext chatQueryContext, SemanticQuery semanticQuery) {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) {
return QueryType.DETAIL;
}
//1. entity queryType
Long dataSetId = parseInfo.getDataSetId();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL());
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
@@ -63,7 +63,7 @@ public class QueryTypeParser implements SemanticParser {
return QueryType.ID;
}
}
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
selectFields.addAll(whereFields);
List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields);
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
@@ -91,7 +91,7 @@ public class QueryTypeParser implements SemanticParser {
}
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());

View File

@@ -38,7 +38,7 @@ public class LLMResponseService {
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setS2SQL(s2SQL);
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery);
return parseInfo;
}

View File

@@ -83,8 +83,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(semanticSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
}
}

View File

@@ -33,6 +33,6 @@ public class LLMSqlQuery extends LLMSemanticQuery {
@Override
public void initS2Sql(SemanticSchema semanticSchema, User user) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
sqlInfo.setCorrectS2SQL(sqlInfo.getS2SQL());
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
}
}

View File

@@ -152,8 +152,8 @@ public class QueryReqBuilder {
public static QuerySqlReq buildS2SQLReq(SqlInfo sqlInfo, Long dataSetId) {
QuerySqlReq querySQLReq = new QuerySqlReq();
if (Objects.nonNull(sqlInfo.getCorrectS2SQL())) {
querySQLReq.setSql(sqlInfo.getCorrectS2SQL());
if (Objects.nonNull(sqlInfo.getCorrectedS2SQL())) {
querySQLReq.setSql(sqlInfo.getCorrectedS2SQL());
}
querySQLReq.setSqlInfo(sqlInfo);
querySQLReq.setDataSetId(dataSetId);

View File

@@ -30,14 +30,14 @@ class AggCorrectorTest {
String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
+ " datediff('day', 数据日期, '2024-06-04') <= 7"
+ " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1";
sqlInfo.setS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.correct(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'"
+ " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户"
+ " ORDER BY SUM(访问次数) DESC LIMIT 1",
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
}
private ChatQueryContext buildQueryContext(Long dataSetId) {

View File

@@ -65,8 +65,8 @@ class SchemaCorrectorTest {
+ "and 商务组 = 'xxx' order by 播放量 desc limit 10";
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
SchemaElement schemaElement = new SchemaElement();
@@ -80,7 +80,7 @@ class SchemaCorrectorTest {
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
parseResult = objectMapper.readValue(json, ParseResult.class);
@@ -92,11 +92,11 @@ class SchemaCorrectorTest {
parseResult.setLinkingValues(linkingValues);
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
semanticParseInfo.getSqlInfo().setS2SQL(sql);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql);
semanticParseInfo.getSqlInfo().setParsedS2SQL(sql);
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
}

View File

@@ -44,12 +44,12 @@ class SelectCorrectorTest {
semanticParseInfo.setQueryType(QueryType.DETAIL);
SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT * FROM 艺人库 WHERE 艺人名='周杰伦'";
sqlInfo.setS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.correct(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
}
private ChatQueryContext buildQueryContext(Long dataSetId) {

View File

@@ -19,84 +19,84 @@ class TimeCorrectorTest {
//1.数据日期 <=
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//2.数据日期 <
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//3.数据日期 >=
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//4.数据日期 >
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//5.no 数据日期
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//6. 数据日期-月 <=
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
+ "AND 数据日期_月 >= '2024-01' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//7. 数据日期-月 >
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//8. no where
sql = "SELECT COUNT(1) FROM 数据库";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectedS2SQL());
}
}

View File

@@ -19,7 +19,7 @@ class WhereCorrectorTest {
SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
ChatQueryContext chatQueryContext = new ChatQueryContext();
@@ -54,7 +54,7 @@ class WhereCorrectorTest {
WhereCorrector whereCorrector = new WhereCorrector();
whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo);
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE "
+ "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND "

View File

@@ -19,7 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
@@ -73,9 +72,6 @@ public class S2DataPermissionAspect {
SemanticQueryReq queryReq = null;
if (objects[0] instanceof SemanticQueryReq) {
queryReq = (SemanticQueryReq) objects[0];
} else if (objects[0] instanceof TranslateSqlReq) {
queryReq = (SemanticQueryReq) ((TranslateSqlReq<?>) objects[0]).getQueryReq();
needQueryData = false;
}
if (queryReq == null) {
throw new InvalidArgumentException("queryReq is not Invalid");

View File

@@ -4,10 +4,9 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -20,12 +19,12 @@ public interface SemanticLayerService {
DataSetSchema getDataSetSchema(Long id);
SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception;
SemanticQueryResp queryByReq(SemanticQueryReq queryReq, User user) throws Exception;
SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
<T> TranslateResp translate(TranslateSqlReq<T> translateSqlReq, User user) throws Exception;
EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user);
List<ItemResp> getDomainDataSetTree();

View File

@@ -26,11 +26,9 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.CostType;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
@@ -41,7 +39,7 @@ import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -247,8 +245,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
List<String> fields = new ArrayList<>();
if (Objects.nonNull(parseInfo.getSqlInfo())
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
fields = SqlSelectHelper.getAllFields(correctorSql);
}
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
@@ -260,13 +258,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
.queryTypeEnum(QueryMethod.SQL).build();
TranslateResp explain = semanticLayerService.translate(translateSqlReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getSql());
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
} else {
log.info("rule begin replace metrics and revise filters!");
//remove unvalid filters
@@ -303,7 +299,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
@@ -334,7 +330,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
List<String> oriMetrics = parseInfo.getMetrics().stream()
.map(SchemaElement::getName).collect(Collectors.toList());
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("before replaceMetrics:{}", correctorSql);
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
@@ -343,7 +339,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
}
log.info("after replaceMetrics:{}", correctorSql);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
}
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
@@ -598,7 +594,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
public void correct(QuerySqlReq querySqlReq, User user) {
SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user);
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectS2SQL());
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
}
@Override
@@ -613,8 +609,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
queryCtx.setSemanticSchema(semanticSchema);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setCorrectS2SQL(querySqlReq.getSql());
sqlInfo.setS2SQL(querySqlReq.getSql());
sqlInfo.setCorrectedS2SQL(querySqlReq.getSql());
sqlInfo.setParsedS2SQL(querySqlReq.getSql());
semanticParseInfo.setSqlInfo(sqlInfo);
semanticParseInfo.setQueryType(QueryType.DETAIL);
@@ -630,7 +626,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
corrector.correct(queryCtx, semanticParseInfo);
}
});
log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectS2SQL());
log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectedS2SQL());
return semanticParseInfo;
}

View File

@@ -19,7 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
@@ -28,7 +27,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -228,12 +227,11 @@ public class S2SemanticLayerService implements SemanticLayerService {
@S2DataPermission
@Override
public <T> TranslateResp translate(TranslateSqlReq<T> translateSqlReq, User user) throws Exception {
T queryReq = translateSqlReq.getQueryReq();
QueryStatement queryStatement = buildQueryStatement((SemanticQueryReq) queryReq, user);
public SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception {
QueryStatement queryStatement = buildQueryStatement(queryReq, user);
semanticTranslator.translate(queryStatement);
return TranslateResp.builder()
.sql(queryStatement.getSql())
return SemanticTranslateResp.builder()
.querySQL(queryStatement.getSql())
.isOk(queryStatement.isOk())
.errMsg(queryStatement.getErrMsg())
.build();

View File

@@ -52,12 +52,12 @@ public class ParseInfoProcessor implements ResultProcessor {
public void updateParseInfo(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
if (StringUtils.isBlank(correctS2SQL)) {
return;
}
// if S2SQL equals correctS2SQL, then not update the parseInfo.
if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
if (correctS2SQL.equals(sqlInfo.getParsedS2SQL())) {
return;
}
List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(correctS2SQL);
@@ -87,15 +87,15 @@ public class ParseInfoProcessor implements ResultProcessor {
if (Objects.isNull(semanticSchema)) {
return;
}
List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectedS2SQL()));
Set<SchemaElement> metrics = getElements(dataSetId, allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectedS2SQL());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions()));
} else if (QueryType.DETAIL.equals(parseInfo.getQueryType())) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectedS2SQL());
List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions()));
}

View File

@@ -4,11 +4,9 @@ import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
@@ -123,15 +121,13 @@ public class ChatWorkflowEngine {
semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class);
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
.queryTypeEnum(QueryMethod.SQL).build();
TranslateResp explain = queryService.translate(translateSqlReq, chatQueryContext.getUser());
parseInfo.getSqlInfo().setQuerySQL(explain.getSql());
SemanticTranslateResp explain = queryService.translate(semanticQueryReq, chatQueryContext.getUser());
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
keyPipelineLog.info("SqlInfoProcessor results:\n"
+ "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}",
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getS2SQL()),
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getCorrectS2SQL()),
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getParsedS2SQL()),
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getCorrectedS2SQL()),
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getQuerySQL()));
} catch (Exception e) {
log.warn("get sql info failed:{}", parseInfo, e);

View File

@@ -2,11 +2,8 @@ package com.tencent.supersonic.headless;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Test;
@@ -21,29 +18,22 @@ public class TranslateTest extends BaseTest {
@Test
public void testSqlExplain() throws Exception {
String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
TranslateSqlReq<QuerySqlReq> translateSqlReq = TranslateSqlReq.<QuerySqlReq>builder()
.queryTypeEnum(QueryMethod.SQL)
.queryReq(QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()))
.build();
TranslateResp explain = semanticLayerService.translate(translateSqlReq, User.getFakeUser());
SemanticTranslateResp explain = semanticLayerService.translate(QueryReqBuilder.buildS2SQLReq(sql,
DataUtils.getMetricAgentView()), User.getFakeUser());
assertNotNull(explain);
assertNotNull(explain.getSql());
assertTrue(explain.getSql().contains("department"));
assertTrue(explain.getSql().contains("pv"));
assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("department"));
assertTrue(explain.getQuerySQL().contains("pv"));
}
@Test
public void testStructExplain() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
TranslateSqlReq<QueryStructReq> translateSqlReq = TranslateSqlReq.<QueryStructReq>builder()
.queryTypeEnum(QueryMethod.STRUCT)
.queryReq(queryStructReq)
.build();
TranslateResp explain = semanticLayerService.translate(translateSqlReq, User.getFakeUser());
SemanticTranslateResp explain = semanticLayerService.translate(queryStructReq, User.getFakeUser());
assertNotNull(explain);
assertNotNull(explain.getSql());
assertTrue(explain.getSql().contains("department"));
assertTrue(explain.getSql().contains("pv"));
assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("department"));
assertTrue(explain.getQuerySQL().contains("pv"));
}
}