(improvement)(headless)Remove unnecessary TranslateSqlReq, use SemanticQueryReq instead.

This commit is contained in:
jerryjzhang
2024-07-09 10:48:48 +08:00
parent 7a376bd9a3
commit f0b4eb46cf
32 changed files with 138 additions and 176 deletions

View File

@@ -44,7 +44,7 @@ public class SqlExecutor implements ChatExecutor {
.agentId(chatExecuteContext.getAgentId()) .agentId(chatExecuteContext.getAgentId())
.status(MemoryStatus.PENDING) .status(MemoryStatus.PENDING)
.question(chatExecuteContext.getQueryText()) .question(chatExecuteContext.getQueryText())
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL()) .s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getParsedS2SQL())
.dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo())) .dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo()))
.createdBy(chatExecuteContext.getUser().getName()) .createdBy(chatExecuteContext.getUser().getName())
.updatedBy(chatExecuteContext.getUser().getName()) .updatedBy(chatExecuteContext.getUser().getName())
@@ -64,12 +64,12 @@ public class SqlExecutor implements ChatExecutor {
ChatContext chatCtx = chatContextService.getOrCreateContext(chatExecuteContext.getChatId()); ChatContext chatCtx = chatContextService.getOrCreateContext(chatExecuteContext.getChatId());
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo(); SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
if (Objects.isNull(parseInfo.getSqlInfo()) if (Objects.isNull(parseInfo.getSqlInfo())
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) { || StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
return null; return null;
} }
QuerySqlReq sqlReq = QuerySqlReq.builder() QuerySqlReq sqlReq = QuerySqlReq.builder()
.sql(parseInfo.getSqlInfo().getCorrectS2SQL()) .sql(parseInfo.getSqlInfo().getCorrectedS2SQL())
.build(); .build();
sqlReq.setSqlInfo(parseInfo.getSqlInfo()); sqlReq.setSqlInfo(parseInfo.getSqlInfo());
sqlReq.setDataSetId(parseInfo.getDataSetId()); sqlReq.setDataSetId(parseInfo.getDataSetId());

View File

@@ -161,7 +161,7 @@ public class NL2SQLParser implements ChatParser {
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId)); String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches()); String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL(); String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL();
String rewrittenQuery = rewriteQuery(RewriteContext.builder() String rewrittenQuery = rewriteQuery(RewriteContext.builder()
.curtQuestion(currentMapResult.getQueryText()) .curtQuestion(currentMapResult.getQueryText())
.histQuestion(lastParseResult.getQueryText()) .histQuestion(lastParseResult.getQueryText())

View File

@@ -5,8 +5,13 @@ import lombok.Data;
@Data @Data
public class SqlInfo { public class SqlInfo {
private String s2SQL; // S2SQL generated by semantic parsers
private String correctS2SQL; private String parsedS2SQL;
// S2SQL corrected by semantic correctors
private String correctedS2SQL;
// SQL to be executed finally
private String querySQL; private String querySQL;
private String sourceId;
} }

View File

@@ -164,7 +164,7 @@ public class QueryStructReq extends SemanticQueryReq {
result.setDataSetId(this.getDataSetId()); result.setDataSetId(this.getDataSetId());
result.setModelIds(this.getModelIdSet()); result.setModelIds(this.getModelIdSet());
result.setParams(new ArrayList<>()); result.setParams(new ArrayList<>());
result.getSqlInfo().setCorrectS2SQL(sql); result.getSqlInfo().setCorrectedS2SQL(sql);
return result; return result;
} }

View File

@@ -1,20 +0,0 @@
package com.tencent.supersonic.headless.api.pojo.request;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
@Data
@ToString
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class TranslateSqlReq<T> {
private QueryMethod queryTypeEnum;
private T queryReq;
}

View File

@@ -13,10 +13,12 @@ import java.io.Serializable;
@Builder @Builder
@AllArgsConstructor @AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
public class TranslateResp implements Serializable { public class SemanticTranslateResp implements Serializable {
private String querySQL;
private String sql;
private boolean isOk; private boolean isOk;
private String errMsg; private String errMsg;
} }

View File

@@ -22,7 +22,7 @@ public class AggCorrector extends BaseSemanticCorrector {
private void addAggregate(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { private void addAggregate(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields( List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
semanticParseInfo.getSqlInfo().getCorrectS2SQL()); semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
if (CollectionUtils.isEmpty(sqlGroupByFields)) { if (CollectionUtils.isEmpty(sqlGroupByFields)) {
return; return;
} }

View File

@@ -29,7 +29,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
try { try {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) { if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) {
return; return;
} }
doCorrect(chatQueryContext, semanticParseInfo); doCorrect(chatQueryContext, semanticParseInfo);
@@ -74,7 +74,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric //add aggregate to all metric
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
Long dataSetId = semanticParseInfo.getDataSet().getDataSet(); Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId); List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId);
@@ -98,7 +98,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return; return;
} }
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate); String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(aggregateSql);
} }
protected List<SchemaElement> getMetricElements(ChatQueryContext chatQueryContext, Long dataSetId) { protected List<SchemaElement> getMetricElements(ChatQueryContext chatQueryContext, Long dataSetId) {

View File

@@ -34,8 +34,8 @@ public class GrammarCorrector extends BaseSemanticCorrector {
} }
public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) { public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL); correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} }
} }

View File

@@ -35,7 +35,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
Long dataSetId = semanticParseInfo.getDataSetId(); Long dataSetId = semanticParseInfo.getDataSetId();
//add dimension group by //add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL(); String correctS2SQL = sqlInfo.getCorrectedS2SQL();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
// check has distinct // check has distinct
if (SqlSelectHelper.hasDistinct(correctS2SQL)) { if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
@@ -68,7 +68,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
Long dataSetId = semanticParseInfo.getDataSetId(); Long dataSetId = semanticParseInfo.getDataSetId();
//add dimension group by //add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL(); String correctS2SQL = sqlInfo.getCorrectedS2SQL();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
//add alias field name //add alias field name
Set<String> dimensions = getDimensions(dataSetId, semanticSchema); Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
@@ -83,7 +83,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return true; return true;
}) })
.collect(Collectors.toSet()); .collect(Collectors.toSet());
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields)); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
} }
} }

View File

@@ -49,19 +49,19 @@ public class HavingCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(metrics)) { if (CollectionUtils.isEmpty(metrics)) {
return; return;
} }
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics); String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(), metrics);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(havingSql);
} }
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) { private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) { if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
return; return;
} }
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL); List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
if (!CollectionUtils.isEmpty(havingExpressionList)) { if (!CollectionUtils.isEmpty(havingExpressionList)) {
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList); String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(replaceSql);
} }
return; return;
} }

View File

@@ -50,21 +50,21 @@ public class SchemaCorrector extends BaseSemanticCorrector {
private void correctAggFunction(SemanticParseInfo semanticParseInfo) { private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum(); Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum); String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectedS2SQL(), aggregateEnum);
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
} }
private void replaceAlias(SemanticParseInfo semanticParseInfo) { private void replaceAlias(SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL()); String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectedS2SQL());
sqlInfo.setCorrectS2SQL(replaceAlias); sqlInfo.setCorrectedS2SQL(replaceAlias);
} }
private void correctFieldName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { private void correctFieldName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId()); Map<String, String> fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap); String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap);
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
} }
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) { private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
@@ -79,8 +79,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames); String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectedS2SQL(), fieldValueToFieldNames);
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
} }
private List<LLMReq.ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) { private List<LLMReq.ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
@@ -111,14 +111,14 @@ public class SchemaCorrector extends BaseSemanticCorrector {
))); )));
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false); String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectedS2SQL(), filedNameToValueMap, false);
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
} }
public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext, public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) { SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL(); String correctS2SQL = sqlInfo.getCorrectedS2SQL();
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL); List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
if (CollectionUtils.isEmpty(whereExpressionList)) { if (CollectionUtils.isEmpty(whereExpressionList)) {
return; return;
@@ -143,7 +143,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet()); .map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
} }
} }

View File

@@ -33,7 +33,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
@Override @Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL); List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL); List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select. // If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
@@ -44,7 +44,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
} }
correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL); correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL);
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL); String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(querySql);
} }
protected String addFieldsToSelect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo, protected String addFieldsToSelect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo,
@@ -65,7 +65,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
} }
needAddFields.removeAll(selectFields); needAddFields.removeAll(selectFields);
String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields)); String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
return addFieldsToSelectSql; return addFieldsToSelectSql;
} }

View File

@@ -45,7 +45,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
} }
private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
//decide whether remove date field from where //decide whether remove date field from where
Environment environment = ContextUtils.getBean(Environment.class); Environment environment = ContextUtils.getBean(Environment.class);
String correctorDate = environment.getProperty("s2.corrector.date"); String correctorDate = environment.getProperty("s2.corrector.date");
@@ -55,12 +55,12 @@ public class TimeCorrector extends BaseSemanticCorrector {
removeFieldNames.add(TimeDimensionEnum.WEEK.getChName()); removeFieldNames.add(TimeDimensionEnum.WEEK.getChName());
removeFieldNames.add(TimeDimensionEnum.MONTH.getChName()); removeFieldNames.add(TimeDimensionEnum.MONTH.getChName());
correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} }
} }
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
//decide whether add date field to where //decide whether add date field to where
@@ -88,11 +88,11 @@ public class TimeCorrector extends BaseSemanticCorrector {
} }
} }
} }
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} }
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) { private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL); DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
if (Objects.isNull(dateBoundInfo)) { if (Objects.isNull(dateBoundInfo)) {
return; return;
@@ -108,14 +108,14 @@ public class TimeCorrector extends BaseSemanticCorrector {
} catch (JSQLParserException e) { } catch (JSQLParserException e) {
log.error("parseCondExpression", e); log.error("parseCondExpression", e);
} }
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} }
} }
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) { private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL); correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} }
} }

View File

@@ -39,7 +39,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters()); String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (StringUtils.isNotEmpty(queryFilter)) { if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to correctS2SQL :{}", queryFilter); log.info("add queryFilter to correctS2SQL :{}", queryFilter);
@@ -50,7 +50,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
log.error("parseCondExpression", e); log.error("parseCondExpression", e);
} }
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression); correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} }
} }
@@ -71,9 +71,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
} }
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions); Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(),
aliasAndBizNameToTechName); aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} }
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) { private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {

View File

@@ -46,14 +46,14 @@ public class QueryTypeParser implements SemanticParser {
private QueryType getQueryType(ChatQueryContext chatQueryContext, SemanticQuery semanticQuery) { private QueryType getQueryType(ChatQueryContext chatQueryContext, SemanticQuery semanticQuery) {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SqlInfo sqlInfo = parseInfo.getSqlInfo(); SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) { if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) {
return QueryType.DETAIL; return QueryType.DETAIL;
} }
//1. entity queryType //1. entity queryType
Long dataSetId = parseInfo.getDataSetId(); Long dataSetId = parseInfo.getDataSetId();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) { if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL()); List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL());
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields); List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) { if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName) Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
@@ -63,7 +63,7 @@ public class QueryTypeParser implements SemanticParser {
return QueryType.ID; return QueryType.ID;
} }
} }
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL()); List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
selectFields.addAll(whereFields); selectFields.addAll(whereFields);
List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields); List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields);
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) { if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
@@ -91,7 +91,7 @@ public class QueryTypeParser implements SemanticParser {
} }
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) { private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL()); List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId); List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
if (CollectionUtils.isNotEmpty(metrics)) { if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet()); Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());

View File

@@ -38,7 +38,7 @@ public class LLMResponseService {
parseInfo.setProperties(properties); parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight)); parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode()); parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setS2SQL(s2SQL); parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery); queryCtx.getCandidateQueries().add(semanticQuery);
return parseInfo; return parseInfo;
} }

View File

@@ -83,8 +83,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
QueryStructReq queryStructReq = convertQueryStruct(); QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(semanticSchema, queryStructReq); convertBizNameToName(semanticSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert(); QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql()); parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql()); parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
} }
} }

View File

@@ -33,6 +33,6 @@ public class LLMSqlQuery extends LLMSemanticQuery {
@Override @Override
public void initS2Sql(SemanticSchema semanticSchema, User user) { public void initS2Sql(SemanticSchema semanticSchema, User user) {
SqlInfo sqlInfo = parseInfo.getSqlInfo(); SqlInfo sqlInfo = parseInfo.getSqlInfo();
sqlInfo.setCorrectS2SQL(sqlInfo.getS2SQL()); sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
} }
} }

View File

@@ -152,8 +152,8 @@ public class QueryReqBuilder {
public static QuerySqlReq buildS2SQLReq(SqlInfo sqlInfo, Long dataSetId) { public static QuerySqlReq buildS2SQLReq(SqlInfo sqlInfo, Long dataSetId) {
QuerySqlReq querySQLReq = new QuerySqlReq(); QuerySqlReq querySQLReq = new QuerySqlReq();
if (Objects.nonNull(sqlInfo.getCorrectS2SQL())) { if (Objects.nonNull(sqlInfo.getCorrectedS2SQL())) {
querySQLReq.setSql(sqlInfo.getCorrectS2SQL()); querySQLReq.setSql(sqlInfo.getCorrectedS2SQL());
} }
querySQLReq.setSqlInfo(sqlInfo); querySQLReq.setSqlInfo(sqlInfo);
querySQLReq.setDataSetId(dataSetId); querySQLReq.setDataSetId(dataSetId);

View File

@@ -30,14 +30,14 @@ class AggCorrectorTest {
String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND" String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
+ " datediff('day', 数据日期, '2024-06-04') <= 7" + " datediff('day', 数据日期, '2024-06-04') <= 7"
+ " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1"; + " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1";
sqlInfo.setS2SQL(sql); sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo); semanticParseInfo.setSqlInfo(sqlInfo);
corrector.correct(chatQueryContext, semanticParseInfo); corrector.correct(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'" Assert.assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'"
+ " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户" + " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户"
+ " ORDER BY SUM(访问次数) DESC LIMIT 1", + " ORDER BY SUM(访问次数) DESC LIMIT 1",
semanticParseInfo.getSqlInfo().getCorrectS2SQL()); semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
} }
private ChatQueryContext buildQueryContext(Long dataSetId) { private ChatQueryContext buildQueryContext(Long dataSetId) {

View File

@@ -65,8 +65,8 @@ class SchemaCorrectorTest {
+ "and 商务组 = 'xxx' order by 播放量 desc limit 10"; + "and 商务组 = 'xxx' order by 播放量 desc limit 10";
SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo(); SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setS2SQL(sql); sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo); semanticParseInfo.setSqlInfo(sqlInfo);
SchemaElement schemaElement = new SchemaElement(); SchemaElement schemaElement = new SchemaElement();
@@ -80,7 +80,7 @@ class SchemaCorrectorTest {
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo); schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
parseResult = objectMapper.readValue(json, ParseResult.class); parseResult = objectMapper.readValue(json, ParseResult.class);
@@ -92,11 +92,11 @@ class SchemaCorrectorTest {
parseResult.setLinkingValues(linkingValues); parseResult.setLinkingValues(linkingValues);
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult); semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql);
semanticParseInfo.getSqlInfo().setS2SQL(sql); semanticParseInfo.getSqlInfo().setParsedS2SQL(sql);
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo); schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
} }

View File

@@ -44,12 +44,12 @@ class SelectCorrectorTest {
semanticParseInfo.setQueryType(QueryType.DETAIL); semanticParseInfo.setQueryType(QueryType.DETAIL);
SqlInfo sqlInfo = new SqlInfo(); SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT * FROM 艺人库 WHERE 艺人名='周杰伦'"; String sql = "SELECT * FROM 艺人库 WHERE 艺人名='周杰伦'";
sqlInfo.setS2SQL(sql); sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo); semanticParseInfo.setSqlInfo(sqlInfo);
corrector.correct(chatQueryContext, semanticParseInfo); corrector.correct(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'", Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
semanticParseInfo.getSqlInfo().getCorrectS2SQL()); semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
} }
private ChatQueryContext buildQueryContext(Long dataSetId) { private ChatQueryContext buildQueryContext(Long dataSetId) {

View File

@@ -19,84 +19,84 @@ class TimeCorrectorTest {
//1.数据日期 <= //1.数据日期 <=
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo); semanticParseInfo.setSqlInfo(sqlInfo);
corrector.doCorrect(chatQueryContext, semanticParseInfo); corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals( Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') " "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1", + "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL()); sqlInfo.getCorrectedS2SQL());
//2.数据日期 < //2.数据日期 <
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1"; + "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo); corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals( Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') " "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1", + "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL()); sqlInfo.getCorrectedS2SQL());
//3.数据日期 >= //3.数据日期 >=
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1"; + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo); corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals( Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 " "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1", + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL()); sqlInfo.getCorrectedS2SQL());
//4.数据日期 > //4.数据日期 >
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1"; + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo); corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals( Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 " "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1", + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL()); sqlInfo.getCorrectedS2SQL());
//5.no 数据日期 //5.no 数据日期
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' GROUP BY 维度1"; + "WHERE 歌手名 = '张三' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo); corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals( Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1", "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL()); sqlInfo.getCorrectedS2SQL());
//6. 数据日期-月 <= //6. 数据日期-月 <=
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1"; + "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo); corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals( Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') " "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
+ "AND 数据日期_月 >= '2024-01' GROUP BY 维度1", + "AND 数据日期_月 >= '2024-01' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL()); sqlInfo.getCorrectedS2SQL());
//7. 数据日期-月 > //7. 数据日期-月 >
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1"; + "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo); corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals( Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 " "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1", + "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL()); sqlInfo.getCorrectedS2SQL());
//8. no where //8. no where
sql = "SELECT COUNT(1) FROM 数据库"; sql = "SELECT COUNT(1) FROM 数据库";
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo); corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL()); Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectedS2SQL());
} }
} }

View File

@@ -19,7 +19,7 @@ class WhereCorrectorTest {
SqlInfo sqlInfo = new SqlInfo(); SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo); semanticParseInfo.setSqlInfo(sqlInfo);
ChatQueryContext chatQueryContext = new ChatQueryContext(); ChatQueryContext chatQueryContext = new ChatQueryContext();
@@ -54,7 +54,7 @@ class WhereCorrectorTest {
WhereCorrector whereCorrector = new WhereCorrector(); WhereCorrector whereCorrector = new WhereCorrector();
whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo); whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo);
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE " Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE "
+ "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND " + "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND "

View File

@@ -19,7 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq; import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
@@ -73,9 +72,6 @@ public class S2DataPermissionAspect {
SemanticQueryReq queryReq = null; SemanticQueryReq queryReq = null;
if (objects[0] instanceof SemanticQueryReq) { if (objects[0] instanceof SemanticQueryReq) {
queryReq = (SemanticQueryReq) objects[0]; queryReq = (SemanticQueryReq) objects[0];
} else if (objects[0] instanceof TranslateSqlReq) {
queryReq = (SemanticQueryReq) ((TranslateSqlReq<?>) objects[0]).getQueryReq();
needQueryData = false;
} }
if (queryReq == null) { if (queryReq == null) {
throw new InvalidArgumentException("queryReq is not Invalid"); throw new InvalidArgumentException("queryReq is not Invalid");

View File

@@ -4,10 +4,9 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo; import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp; import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -20,12 +19,12 @@ public interface SemanticLayerService {
DataSetSchema getDataSetSchema(Long id); DataSetSchema getDataSetSchema(Long id);
SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception;
SemanticQueryResp queryByReq(SemanticQueryReq queryReq, User user) throws Exception; SemanticQueryResp queryByReq(SemanticQueryReq queryReq, User user) throws Exception;
SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user); SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
<T> TranslateResp translate(TranslateSqlReq<T> translateSqlReq, User user) throws Exception;
EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user); EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user);
List<ItemResp> getDomainDataSetTree(); List<ItemResp> getDomainDataSetTree();

View File

@@ -26,11 +26,9 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation; import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.CostType; import com.tencent.supersonic.headless.api.pojo.enums.CostType;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
@@ -41,7 +39,7 @@ import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo; import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -247,8 +245,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
List<String> fields = new ArrayList<>(); List<String> fields = new ArrayList<>();
if (Objects.nonNull(parseInfo.getSqlInfo()) if (Objects.nonNull(parseInfo.getSqlInfo())
&& StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) { && StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
fields = SqlSelectHelper.getAllFields(correctorSql); fields = SqlSelectHelper.getAllFields(correctorSql);
} }
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode()) if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())
@@ -260,13 +258,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) { } else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
log.info("llm begin revise filters!"); log.info("llm begin revise filters!");
String correctorSql = reviseCorrectS2SQL(queryData, parseInfo); String correctorSql = reviseCorrectS2SQL(queryData, parseInfo);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql); parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo); semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq) SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
.queryTypeEnum(QueryMethod.SQL).build(); parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
TranslateResp explain = semanticLayerService.translate(translateSqlReq, user);
parseInfo.getSqlInfo().setQuerySQL(explain.getSql());
} else { } else {
log.info("rule begin replace metrics and revise filters!"); log.info("rule begin replace metrics and revise filters!");
//remove unvalid filters //remove unvalid filters
@@ -303,7 +299,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>(); Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>(); Map<String, Map<String, String>> havingFiledNameToValueMap = new HashMap<>();
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("correctorSql before replacing:{}", correctorSql); log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter // get where filter and having filter
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql); List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
@@ -334,7 +330,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) { private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
List<String> oriMetrics = parseInfo.getMetrics().stream() List<String> oriMetrics = parseInfo.getMetrics().stream()
.map(SchemaElement::getName).collect(Collectors.toList()); .map(SchemaElement::getName).collect(Collectors.toList());
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("before replaceMetrics:{}", correctorSql); log.info("before replaceMetrics:{}", correctorSql);
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric); log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
Map<String, Pair<String, String>> fieldMap = new HashMap<>(); Map<String, Pair<String, String>> fieldMap = new HashMap<>();
@@ -343,7 +339,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap); correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
} }
log.info("after replaceMetrics:{}", correctorSql); log.info("after replaceMetrics:{}", correctorSql);
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql); parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
} }
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo, private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
@@ -598,7 +594,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
public void correct(QuerySqlReq querySqlReq, User user) { public void correct(QuerySqlReq querySqlReq, User user) {
SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user); SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user);
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectS2SQL()); querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
} }
@Override @Override
@@ -613,8 +609,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
queryCtx.setSemanticSchema(semanticSchema); queryCtx.setSemanticSchema(semanticSchema);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo(); SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setCorrectS2SQL(querySqlReq.getSql()); sqlInfo.setCorrectedS2SQL(querySqlReq.getSql());
sqlInfo.setS2SQL(querySqlReq.getSql()); sqlInfo.setParsedS2SQL(querySqlReq.getSql());
semanticParseInfo.setSqlInfo(sqlInfo); semanticParseInfo.setSqlInfo(sqlInfo);
semanticParseInfo.setQueryType(QueryType.DETAIL); semanticParseInfo.setQueryType(QueryType.DETAIL);
@@ -630,7 +626,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
corrector.correct(queryCtx, semanticParseInfo); corrector.correct(queryCtx, semanticParseInfo);
} }
}); });
log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectS2SQL()); log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectedS2SQL());
return semanticParseInfo; return semanticParseInfo;
} }

View File

@@ -19,7 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig; import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
@@ -28,7 +27,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq; import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp; import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -228,12 +227,11 @@ public class S2SemanticLayerService implements SemanticLayerService {
@S2DataPermission @S2DataPermission
@Override @Override
public <T> TranslateResp translate(TranslateSqlReq<T> translateSqlReq, User user) throws Exception { public SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception {
T queryReq = translateSqlReq.getQueryReq(); QueryStatement queryStatement = buildQueryStatement(queryReq, user);
QueryStatement queryStatement = buildQueryStatement((SemanticQueryReq) queryReq, user);
semanticTranslator.translate(queryStatement); semanticTranslator.translate(queryStatement);
return TranslateResp.builder() return SemanticTranslateResp.builder()
.sql(queryStatement.getSql()) .querySQL(queryStatement.getSql())
.isOk(queryStatement.isOk()) .isOk(queryStatement.isOk())
.errMsg(queryStatement.getErrMsg()) .errMsg(queryStatement.getErrMsg())
.build(); .build();

View File

@@ -52,12 +52,12 @@ public class ParseInfoProcessor implements ResultProcessor {
public void updateParseInfo(SemanticParseInfo parseInfo) { public void updateParseInfo(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo(); SqlInfo sqlInfo = parseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL(); String correctS2SQL = sqlInfo.getCorrectedS2SQL();
if (StringUtils.isBlank(correctS2SQL)) { if (StringUtils.isBlank(correctS2SQL)) {
return; return;
} }
// if S2SQL equals correctS2SQL, then not update the parseInfo. // if S2SQL equals correctS2SQL, then not update the parseInfo.
if (correctS2SQL.equals(sqlInfo.getS2SQL())) { if (correctS2SQL.equals(sqlInfo.getParsedS2SQL())) {
return; return;
} }
List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(correctS2SQL); List<FieldExpression> expressions = SqlSelectHelper.getFilterExpression(correctS2SQL);
@@ -87,15 +87,15 @@ public class ParseInfoProcessor implements ResultProcessor {
if (Objects.isNull(semanticSchema)) { if (Objects.isNull(semanticSchema)) {
return; return;
} }
List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL())); List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectedS2SQL()));
Set<SchemaElement> metrics = getElements(dataSetId, allFields, semanticSchema.getMetrics()); Set<SchemaElement> metrics = getElements(dataSetId, allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics); parseInfo.setMetrics(metrics);
if (QueryType.METRIC.equals(parseInfo.getQueryType())) { if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL()); List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectedS2SQL());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields); List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions())); parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions()));
} else if (QueryType.DETAIL.equals(parseInfo.getQueryType())) { } else if (QueryType.DETAIL.equals(parseInfo.getQueryType())) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL()); List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectedS2SQL());
List<String> selectDimensions = getFieldsExceptDate(selectFields); List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions())); parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions()));
} }

View File

@@ -4,11 +4,9 @@ import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState; import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.ChatContext; import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector; import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
@@ -123,15 +121,13 @@ public class ChatWorkflowEngine {
semanticQuery.setParseInfo(parseInfo); semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class); SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class);
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq) SemanticTranslateResp explain = queryService.translate(semanticQueryReq, chatQueryContext.getUser());
.queryTypeEnum(QueryMethod.SQL).build(); parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
TranslateResp explain = queryService.translate(translateSqlReq, chatQueryContext.getUser());
parseInfo.getSqlInfo().setQuerySQL(explain.getSql());
keyPipelineLog.info("SqlInfoProcessor results:\n" keyPipelineLog.info("SqlInfoProcessor results:\n"
+ "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}", + "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}",
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getS2SQL()), StringUtils.normalizeSpace(parseInfo.getSqlInfo().getParsedS2SQL()),
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getCorrectS2SQL()), StringUtils.normalizeSpace(parseInfo.getSqlInfo().getCorrectedS2SQL()),
StringUtils.normalizeSpace(parseInfo.getSqlInfo().getQuerySQL())); StringUtils.normalizeSpace(parseInfo.getSqlInfo().getQuerySQL()));
} catch (Exception e) { } catch (Exception e) {
log.warn("get sql info failed:{}", parseInfo, e); log.warn("get sql info failed:{}", parseInfo, e);

View File

@@ -2,11 +2,8 @@ package com.tencent.supersonic.headless;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -21,29 +18,22 @@ public class TranslateTest extends BaseTest {
@Test @Test
public void testSqlExplain() throws Exception { public void testSqlExplain() throws Exception {
String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "; String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
TranslateSqlReq<QuerySqlReq> translateSqlReq = TranslateSqlReq.<QuerySqlReq>builder() SemanticTranslateResp explain = semanticLayerService.translate(QueryReqBuilder.buildS2SQLReq(sql,
.queryTypeEnum(QueryMethod.SQL) DataUtils.getMetricAgentView()), User.getFakeUser());
.queryReq(QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()))
.build();
TranslateResp explain = semanticLayerService.translate(translateSqlReq, User.getFakeUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getSql()); assertNotNull(explain.getQuerySQL());
assertTrue(explain.getSql().contains("department")); assertTrue(explain.getQuerySQL().contains("department"));
assertTrue(explain.getSql().contains("pv")); assertTrue(explain.getQuerySQL().contains("pv"));
} }
@Test @Test
public void testStructExplain() throws Exception { public void testStructExplain() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department")); QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
TranslateSqlReq<QueryStructReq> translateSqlReq = TranslateSqlReq.<QueryStructReq>builder() SemanticTranslateResp explain = semanticLayerService.translate(queryStructReq, User.getFakeUser());
.queryTypeEnum(QueryMethod.STRUCT)
.queryReq(queryStructReq)
.build();
TranslateResp explain = semanticLayerService.translate(translateSqlReq, User.getFakeUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getSql()); assertNotNull(explain.getQuerySQL());
assertTrue(explain.getSql().contains("department")); assertTrue(explain.getQuerySQL().contains("department"));
assertTrue(explain.getSql().contains("pv")); assertTrue(explain.getQuerySQL().contains("pv"));
} }
} }