diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index de06afcaf..d82108bd0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -91,6 +91,7 @@ public class NL2SQLParser implements ChatQueryParser { // mapModes Set requestedDatasets = queryNLReq.getDataSetIds(); List candidateParses = Lists.newArrayList(); + StringBuilder errMsg = new StringBuilder(); for (Long datasetId : requestedDatasets) { queryNLReq.setDataSetIds(Collections.singleton(datasetId)); ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId()); @@ -104,6 +105,7 @@ public class NL2SQLParser implements ChatQueryParser { doParse(queryNLReq, parseResp); } if (parseResp.getSelectedParses().isEmpty()) { + errMsg.append(parseResp.getErrorMsg()); continue; } // for one dataset select the top 1 parse after sorting @@ -116,6 +118,10 @@ public class NL2SQLParser implements ChatQueryParser { SemanticParseInfo.sort(candidateParses); parseContext.getResponse().setSelectedParses( candidateParses.subList(0, Math.min(parserShowCount, candidateParses.size()))); + if (parseContext.getResponse().getSelectedParses().isEmpty()) { + parseContext.getResponse().setState(ParseResp.ParseState.FAILED); + parseContext.getResponse().setErrorMsg(errMsg.toString()); + } } // next go with llm-based parsers unless LLM is disabled or use feedback is needed. diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ColumnSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ColumnSchema.java index 79670baf0..1940cdf36 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ColumnSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ColumnSchema.java @@ -15,7 +15,7 @@ public class ColumnSchema { private FieldType filedType; - private AggOperatorEnum agg; + private AggOperatorEnum agg = AggOperatorEnum.UNKNOWN; private String name; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/FieldType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/FieldType.java index ce24f76c9..0f0e8afa3 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/FieldType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/FieldType.java @@ -1,5 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum FieldType { - primary_key, foreign_key, partition_time, time, dimension, measure; + primary_key, foreign_key, partition_time, time, categorical, measure; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java index 9385bb28b..8dc9ffa0a 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java @@ -9,6 +9,10 @@ import java.util.List; @Data public class ModelBuildReq { + private String name; + + private String bizName; + private Long databaseId; private Long domainId; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java index 752f193c6..2880111f2 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java @@ -211,7 +211,9 @@ public class QueryStructReq extends SemanticQueryReq { SelectItem selectExpressionItem = new SelectItem(function); String alias = StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName; - selectExpressionItem.setAlias(new Alias(alias)); + if (!alias.equals(columnName)) { + selectExpressionItem.setAlias(new Alias(alias)); + } return selectExpressionItem; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java index b03521996..a61cc7d76 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java @@ -84,7 +84,7 @@ public abstract class BaseDbAdaptor implements DbAdaptor { return connection.getMetaData(); } - protected static FieldType classifyColumnType(String typeName) { + public FieldType classifyColumnType(String typeName) { switch (typeName.toUpperCase()) { case "INT": case "INTEGER": @@ -101,7 +101,7 @@ public abstract class BaseDbAdaptor implements DbAdaptor { case "TIMESTAMP": return FieldType.time; default: - return FieldType.dimension; + return FieldType.categorical; } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/DbAdaptor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/DbAdaptor.java index f2e650177..94862f633 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/DbAdaptor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/DbAdaptor.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.core.adaptor.db; import com.tencent.supersonic.headless.api.pojo.DBColumn; +import com.tencent.supersonic.headless.api.pojo.enums.FieldType; import com.tencent.supersonic.headless.core.pojo.ConnectInfo; import java.sql.SQLException; @@ -19,4 +20,6 @@ public interface DbAdaptor { List getColumns(ConnectInfo connectInfo, String schemaName, String tableName) throws SQLException; + + FieldType classifyColumnType(String typeName); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java index ff4cefe5a..f04fb0e5f 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java @@ -114,7 +114,8 @@ public class PostgresqlAdaptor extends BaseDbAdaptor { return dbColumns; } - protected static FieldType classifyColumnType(String typeName) { + @Override + public FieldType classifyColumnType(String typeName) { switch (typeName.toUpperCase()) { case "INT": case "INTEGER": @@ -141,7 +142,7 @@ public class PostgresqlAdaptor extends BaseDbAdaptor { case "CHARACTER": case "UUID": default: - return FieldType.dimension; + return FieldType.categorical; } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/RuleSemanticModeller.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/RuleSemanticModeller.java index b27c25169..324c0cb61 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/RuleSemanticModeller.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/modeller/RuleSemanticModeller.java @@ -23,7 +23,7 @@ public class RuleSemanticModeller implements SemanticModeller { private ColumnSchema convert(DBColumn dbColumn) { ColumnSchema columnSchema = new ColumnSchema(); - columnSchema.setName(dbColumn.getComment()); + columnSchema.setName(dbColumn.getColumnName()); columnSchema.setColumnName(dbColumn.getColumnName()); columnSchema.setComment(dbColumn.getComment()); columnSchema.setDataType(dbColumn.getDataType()); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java index 8fe99042c..8abff2ba5 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java @@ -79,7 +79,7 @@ public class DataSetServiceImpl extends ServiceImpl DataSetDO dataSetDO = convert(dataSetReq); dataSetDO.setStatus(StatusEnum.ONLINE.getCode()); DataSetResp dataSetResp = convert(dataSetDO); - conflictCheck(dataSetResp); + // conflictCheck(dataSetResp); save(dataSetDO); dataSetResp.setId(dataSetDO.getId()); return dataSetResp; @@ -90,7 +90,7 @@ public class DataSetServiceImpl extends ServiceImpl dataSetReq.updatedBy(user.getName()); DataSetDO dataSetDO = convert(dataSetReq); DataSetResp dataSetResp = convert(dataSetDO); - conflictCheck(dataSetResp); + // conflictCheck(dataSetResp); updateById(dataSetDO); return dataSetResp; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java index 976d90fbd..a686d8345 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java @@ -225,6 +225,9 @@ public class DatabaseServiceImpl extends ServiceImpl columns = getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql()); + DatabaseResp databaseResp = getDatabase(modelBuildReq.getDatabaseId()); + DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType()); + columns.forEach(c -> c.setFieldType(engineAdaptor.classifyColumnType(c.getDataType()))); dbColumnMap.put(modelBuildReq.getSql(), columns); } else { for (String table : modelBuildReq.getTables()) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java index d004c9a3b..88c78a2a4 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java @@ -165,8 +165,8 @@ public class ModelConverter { public static ModelReq convert(ModelSchema modelSchema, ModelBuildReq modelBuildReq, String tableName) { ModelReq modelReq = new ModelReq(); - modelReq.setName(modelSchema.getName()); - modelReq.setBizName(modelSchema.getBizName()); + modelReq.setName(modelBuildReq.getName()); + modelReq.setBizName(modelBuildReq.getBizName()); modelReq.setDatabaseId(modelBuildReq.getDatabaseId()); modelReq.setDomainId(modelBuildReq.getDomainId()); ModelDetail modelDetail = new ModelDetail(); @@ -198,10 +198,12 @@ public class ModelConverter { } private static IdentifyType getIdentifyType(FieldType fieldType) { - if (FieldType.foreign_key.equals(fieldType) || FieldType.primary_key.equals(fieldType)) { + if (FieldType.primary_key.equals(fieldType)) { return IdentifyType.primary; - } else { + } else if (FieldType.foreign_key.equals(fieldType)) { return IdentifyType.foreign; + } else { + return null; } } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SemanticModellerTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SemanticModellerTest.java index a8793fb91..b66c58f55 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SemanticModellerTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SemanticModellerTest.java @@ -42,7 +42,7 @@ public class SemanticModellerTest extends BaseTest { Assertions.assertEquals(2, userModelSchema.getColumnSchemas().size()); Assertions.assertEquals(FieldType.primary_key, userModelSchema.getColumnByName("user_name").getFiledType()); - Assertions.assertEquals(FieldType.dimension, + Assertions.assertEquals(FieldType.categorical, userModelSchema.getColumnByName("department").getFiledType()); ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis"); @@ -51,7 +51,7 @@ public class SemanticModellerTest extends BaseTest { stayTimeModelSchema.getColumnByName("user_name").getFiledType()); Assertions.assertEquals(FieldType.partition_time, stayTimeModelSchema.getColumnByName("imp_date").getFiledType()); - Assertions.assertEquals(FieldType.dimension, + Assertions.assertEquals(FieldType.categorical, stayTimeModelSchema.getColumnByName("page").getFiledType()); Assertions.assertEquals(FieldType.measure, stayTimeModelSchema.getColumnByName("stay_hours").getFiledType()); @@ -75,9 +75,9 @@ public class SemanticModellerTest extends BaseTest { Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size()); Assertions.assertEquals(FieldType.partition_time, pvModelSchema.getColumnByName("imp_date").getFiledType()); - Assertions.assertEquals(FieldType.dimension, + Assertions.assertEquals(FieldType.categorical, pvModelSchema.getColumnByName("user_name").getFiledType()); - Assertions.assertEquals(FieldType.dimension, + Assertions.assertEquals(FieldType.categorical, pvModelSchema.getColumnByName("page").getFiledType()); Assertions.assertEquals(FieldType.measure, pvModelSchema.getColumnByName("pv").getFiledType());