(improvement)(headless)Remove unnecessary TranslateSqlReq, use SemanticQueryReq instead.

This commit is contained in:
jerryjzhang
2024-07-09 10:48:48 +08:00
parent 7a376bd9a3
commit f0b4eb46cf
32 changed files with 138 additions and 176 deletions

View File

@@ -19,7 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
@@ -73,9 +72,6 @@ public class S2DataPermissionAspect {
SemanticQueryReq queryReq = null;
if (objects[0] instanceof SemanticQueryReq) {
queryReq = (SemanticQueryReq) objects[0];
} else if (objects[0] instanceof TranslateSqlReq) {
queryReq = (SemanticQueryReq) ((TranslateSqlReq<?>) objects[0]).getQueryReq();
needQueryData = false;
}
if (queryReq == null) {
throw new InvalidArgumentException("queryReq is not Invalid");

View File

@@ -4,10 +4,9 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -20,12 +19,12 @@ public interface SemanticLayerService {
DataSetSchema getDataSetSchema(Long id);
SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception;
SemanticQueryResp queryByReq(SemanticQueryReq queryReq, User user) throws Exception;
SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
<T> TranslateResp translate(TranslateSqlReq<T> translateSqlReq, User user) throws Exception;
EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user);
List<ItemResp> getDomainDataSetTree();

View File

@@ -26,11 +26,9 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.CostType;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
@@ -41,7 +39,7 @@ import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -247,8 +245,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
List<String> fields = new ArrayList<>();
if (Objects.nonNull(parseInfo.getSqlInfo())
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
fields = SqlSelectHelper.getAllFields(correctorSql);
}
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
@@ -260,13 +258,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
.queryTypeEnum(QueryMethod.SQL).build();
TranslateResp explain = semanticLayerService.translate(translateSqlReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getSql());
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
} else {
log.info("rule begin replace metrics and revise filters!");
//remove unvalid filters
@@ -303,7 +299,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
@@ -334,7 +330,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
List<String> oriMetrics = parseInfo.getMetrics().stream()
.map(SchemaElement::getName).collect(Collectors.toList());
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("before replaceMetrics:{}", correctorSql);
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
@@ -343,7 +339,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
}
log.info("after replaceMetrics:{}", correctorSql);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
}
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
@@ -598,7 +594,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
public void correct(QuerySqlReq querySqlReq, User user) {
SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user);
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectS2SQL());
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
}
@Override
@@ -613,8 +609,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
queryCtx.setSemanticSchema(semanticSchema);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setCorrectS2SQL(querySqlReq.getSql());
sqlInfo.setS2SQL(querySqlReq.getSql());
sqlInfo.setCorrectedS2SQL(querySqlReq.getSql());
sqlInfo.setParsedS2SQL(querySqlReq.getSql());
semanticParseInfo.setSqlInfo(sqlInfo);
semanticParseInfo.setQueryType(QueryType.DETAIL);
@@ -630,7 +626,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
corrector.correct(queryCtx, semanticParseInfo);
}
});
log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectS2SQL());
log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectedS2SQL());
return semanticParseInfo;
}

View File

@@ -19,7 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
@@ -28,7 +27,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -228,12 +227,11 @@ public class S2SemanticLayerService implements SemanticLayerService {
@S2DataPermission
@Override
public <T> TranslateResp translate(TranslateSqlReq<T> translateSqlReq, User user) throws Exception {
T queryReq = translateSqlReq.getQueryReq();
QueryStatement queryStatement = buildQueryStatement((SemanticQueryReq) queryReq, user);
public SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception {
QueryStatement queryStatement = buildQueryStatement(queryReq, user);
semanticTranslator.translate(queryStatement);
return TranslateResp.builder()
.sql(queryStatement.getSql())
return SemanticTranslateResp.builder()
.querySQL(queryStatement.getSql())
.isOk(queryStatement.isOk())
.errMsg(queryStatement.getErrMsg())
.build();

View File

@@ -52,12 +52,12 @@ public class ParseInfoProcessor implements ResultProcessor {
public void updateParseInfo(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
if (StringUtils.isBlank(correctS2SQL)) {
return;
}
// if S2SQL equals correctS2SQL, then not update the parseInfo.
if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
if (correctS2SQL.equals(sqlInfo.getParsedS2SQL())) {
return;
}
List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(correctS2SQL);
@@ -87,15 +87,15 @@ public class ParseInfoProcessor implements ResultProcessor {
if (Objects.isNull(semanticSchema)) {
return;
}
List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectedS2SQL()));
Set<SchemaElement> metrics = getElements(dataSetId, allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectedS2SQL());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions()));
} else if (QueryType.DETAIL.equals(parseInfo.getQueryType())) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectedS2SQL());
List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions()));
}

View File

@@ -4,11 +4,9 @@ import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
@@ -123,15 +121,13 @@ public class ChatWorkflowEngine {
semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class);
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
.queryTypeEnum(QueryMethod.SQL).build();
TranslateResp explain = queryService.translate(translateSqlReq, chatQueryContext.getUser());
parseInfo.getSqlInfo().setQuerySQL(explain.getSql());
SemanticTranslateResp explain = queryService.translate(semanticQueryReq, chatQueryContext.getUser());
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
keyPipelineLog.info("SqlInfoProcessor results:\n"
+ "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}",
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getS2SQL()),
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getCorrectS2SQL()),
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getParsedS2SQL()),
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getCorrectedS2SQL()),
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getQuerySQL()));
} catch (Exception e) {
log.warn("get sql info failed:{}", parseInfo, e);