mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[Fix][launcher]Fix a number of issues related to semantic modeling.
This commit is contained in:
@@ -91,6 +91,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
// mapModes
|
// mapModes
|
||||||
Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
|
Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
|
||||||
List<SemanticParseInfo> candidateParses = Lists.newArrayList();
|
List<SemanticParseInfo> candidateParses = Lists.newArrayList();
|
||||||
|
StringBuilder errMsg = new StringBuilder();
|
||||||
for (Long datasetId : requestedDatasets) {
|
for (Long datasetId : requestedDatasets) {
|
||||||
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
|
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
|
||||||
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
|
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
|
||||||
@@ -104,6 +105,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
doParse(queryNLReq, parseResp);
|
doParse(queryNLReq, parseResp);
|
||||||
}
|
}
|
||||||
if (parseResp.getSelectedParses().isEmpty()) {
|
if (parseResp.getSelectedParses().isEmpty()) {
|
||||||
|
errMsg.append(parseResp.getErrorMsg());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// for one dataset select the top 1 parse after sorting
|
// for one dataset select the top 1 parse after sorting
|
||||||
@@ -116,6 +118,10 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
SemanticParseInfo.sort(candidateParses);
|
SemanticParseInfo.sort(candidateParses);
|
||||||
parseContext.getResponse().setSelectedParses(
|
parseContext.getResponse().setSelectedParses(
|
||||||
candidateParses.subList(0, Math.min(parserShowCount, candidateParses.size())));
|
candidateParses.subList(0, Math.min(parserShowCount, candidateParses.size())));
|
||||||
|
if (parseContext.getResponse().getSelectedParses().isEmpty()) {
|
||||||
|
parseContext.getResponse().setState(ParseResp.ParseState.FAILED);
|
||||||
|
parseContext.getResponse().setErrorMsg(errMsg.toString());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// next go with llm-based parsers unless LLM is disabled or use feedback is needed.
|
// next go with llm-based parsers unless LLM is disabled or use feedback is needed.
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ public class ColumnSchema {
|
|||||||
|
|
||||||
private FieldType filedType;
|
private FieldType filedType;
|
||||||
|
|
||||||
private AggOperatorEnum agg;
|
private AggOperatorEnum agg = AggOperatorEnum.UNKNOWN;
|
||||||
|
|
||||||
private String name;
|
private String name;
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.enums;
|
package com.tencent.supersonic.headless.api.pojo.enums;
|
||||||
|
|
||||||
public enum FieldType {
|
public enum FieldType {
|
||||||
primary_key, foreign_key, partition_time, time, dimension, measure;
|
primary_key, foreign_key, partition_time, time, categorical, measure;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ import java.util.List;
|
|||||||
@Data
|
@Data
|
||||||
public class ModelBuildReq {
|
public class ModelBuildReq {
|
||||||
|
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
private String bizName;
|
||||||
|
|
||||||
private Long databaseId;
|
private Long databaseId;
|
||||||
|
|
||||||
private Long domainId;
|
private Long domainId;
|
||||||
|
|||||||
@@ -211,7 +211,9 @@ public class QueryStructReq extends SemanticQueryReq {
|
|||||||
SelectItem selectExpressionItem = new SelectItem(function);
|
SelectItem selectExpressionItem = new SelectItem(function);
|
||||||
String alias =
|
String alias =
|
||||||
StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName;
|
StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName;
|
||||||
selectExpressionItem.setAlias(new Alias(alias));
|
if (!alias.equals(columnName)) {
|
||||||
|
selectExpressionItem.setAlias(new Alias(alias));
|
||||||
|
}
|
||||||
return selectExpressionItem;
|
return selectExpressionItem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
|
|||||||
return connection.getMetaData();
|
return connection.getMetaData();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected static FieldType classifyColumnType(String typeName) {
|
public FieldType classifyColumnType(String typeName) {
|
||||||
switch (typeName.toUpperCase()) {
|
switch (typeName.toUpperCase()) {
|
||||||
case "INT":
|
case "INT":
|
||||||
case "INTEGER":
|
case "INTEGER":
|
||||||
@@ -101,7 +101,7 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
|
|||||||
case "TIMESTAMP":
|
case "TIMESTAMP":
|
||||||
return FieldType.time;
|
return FieldType.time;
|
||||||
default:
|
default:
|
||||||
return FieldType.dimension;
|
return FieldType.categorical;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.core.adaptor.db;
|
package com.tencent.supersonic.headless.core.adaptor.db;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||||
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
||||||
|
|
||||||
import java.sql.SQLException;
|
import java.sql.SQLException;
|
||||||
@@ -19,4 +20,6 @@ public interface DbAdaptor {
|
|||||||
|
|
||||||
List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
|
List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
|
||||||
throws SQLException;
|
throws SQLException;
|
||||||
|
|
||||||
|
FieldType classifyColumnType(String typeName);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,7 +114,8 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
|
|||||||
return dbColumns;
|
return dbColumns;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected static FieldType classifyColumnType(String typeName) {
|
@Override
|
||||||
|
public FieldType classifyColumnType(String typeName) {
|
||||||
switch (typeName.toUpperCase()) {
|
switch (typeName.toUpperCase()) {
|
||||||
case "INT":
|
case "INT":
|
||||||
case "INTEGER":
|
case "INTEGER":
|
||||||
@@ -141,7 +142,7 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
|
|||||||
case "CHARACTER":
|
case "CHARACTER":
|
||||||
case "UUID":
|
case "UUID":
|
||||||
default:
|
default:
|
||||||
return FieldType.dimension;
|
return FieldType.categorical;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ public class RuleSemanticModeller implements SemanticModeller {
|
|||||||
|
|
||||||
private ColumnSchema convert(DBColumn dbColumn) {
|
private ColumnSchema convert(DBColumn dbColumn) {
|
||||||
ColumnSchema columnSchema = new ColumnSchema();
|
ColumnSchema columnSchema = new ColumnSchema();
|
||||||
columnSchema.setName(dbColumn.getComment());
|
columnSchema.setName(dbColumn.getColumnName());
|
||||||
columnSchema.setColumnName(dbColumn.getColumnName());
|
columnSchema.setColumnName(dbColumn.getColumnName());
|
||||||
columnSchema.setComment(dbColumn.getComment());
|
columnSchema.setComment(dbColumn.getComment());
|
||||||
columnSchema.setDataType(dbColumn.getDataType());
|
columnSchema.setDataType(dbColumn.getDataType());
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
|
|||||||
DataSetDO dataSetDO = convert(dataSetReq);
|
DataSetDO dataSetDO = convert(dataSetReq);
|
||||||
dataSetDO.setStatus(StatusEnum.ONLINE.getCode());
|
dataSetDO.setStatus(StatusEnum.ONLINE.getCode());
|
||||||
DataSetResp dataSetResp = convert(dataSetDO);
|
DataSetResp dataSetResp = convert(dataSetDO);
|
||||||
conflictCheck(dataSetResp);
|
// conflictCheck(dataSetResp);
|
||||||
save(dataSetDO);
|
save(dataSetDO);
|
||||||
dataSetResp.setId(dataSetDO.getId());
|
dataSetResp.setId(dataSetDO.getId());
|
||||||
return dataSetResp;
|
return dataSetResp;
|
||||||
@@ -90,7 +90,7 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
|
|||||||
dataSetReq.updatedBy(user.getName());
|
dataSetReq.updatedBy(user.getName());
|
||||||
DataSetDO dataSetDO = convert(dataSetReq);
|
DataSetDO dataSetDO = convert(dataSetReq);
|
||||||
DataSetResp dataSetResp = convert(dataSetDO);
|
DataSetResp dataSetResp = convert(dataSetDO);
|
||||||
conflictCheck(dataSetResp);
|
// conflictCheck(dataSetResp);
|
||||||
updateById(dataSetDO);
|
updateById(dataSetDO);
|
||||||
return dataSetResp;
|
return dataSetResp;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -225,6 +225,9 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
|
|||||||
if (StringUtils.isNotBlank(modelBuildReq.getSql())) {
|
if (StringUtils.isNotBlank(modelBuildReq.getSql())) {
|
||||||
List<DBColumn> columns =
|
List<DBColumn> columns =
|
||||||
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql());
|
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql());
|
||||||
|
DatabaseResp databaseResp = getDatabase(modelBuildReq.getDatabaseId());
|
||||||
|
DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
|
||||||
|
columns.forEach(c -> c.setFieldType(engineAdaptor.classifyColumnType(c.getDataType())));
|
||||||
dbColumnMap.put(modelBuildReq.getSql(), columns);
|
dbColumnMap.put(modelBuildReq.getSql(), columns);
|
||||||
} else {
|
} else {
|
||||||
for (String table : modelBuildReq.getTables()) {
|
for (String table : modelBuildReq.getTables()) {
|
||||||
|
|||||||
@@ -165,8 +165,8 @@ public class ModelConverter {
|
|||||||
public static ModelReq convert(ModelSchema modelSchema, ModelBuildReq modelBuildReq,
|
public static ModelReq convert(ModelSchema modelSchema, ModelBuildReq modelBuildReq,
|
||||||
String tableName) {
|
String tableName) {
|
||||||
ModelReq modelReq = new ModelReq();
|
ModelReq modelReq = new ModelReq();
|
||||||
modelReq.setName(modelSchema.getName());
|
modelReq.setName(modelBuildReq.getName());
|
||||||
modelReq.setBizName(modelSchema.getBizName());
|
modelReq.setBizName(modelBuildReq.getBizName());
|
||||||
modelReq.setDatabaseId(modelBuildReq.getDatabaseId());
|
modelReq.setDatabaseId(modelBuildReq.getDatabaseId());
|
||||||
modelReq.setDomainId(modelBuildReq.getDomainId());
|
modelReq.setDomainId(modelBuildReq.getDomainId());
|
||||||
ModelDetail modelDetail = new ModelDetail();
|
ModelDetail modelDetail = new ModelDetail();
|
||||||
@@ -198,10 +198,12 @@ public class ModelConverter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static IdentifyType getIdentifyType(FieldType fieldType) {
|
private static IdentifyType getIdentifyType(FieldType fieldType) {
|
||||||
if (FieldType.foreign_key.equals(fieldType) || FieldType.primary_key.equals(fieldType)) {
|
if (FieldType.primary_key.equals(fieldType)) {
|
||||||
return IdentifyType.primary;
|
return IdentifyType.primary;
|
||||||
} else {
|
} else if (FieldType.foreign_key.equals(fieldType)) {
|
||||||
return IdentifyType.foreign;
|
return IdentifyType.foreign;
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ public class SemanticModellerTest extends BaseTest {
|
|||||||
Assertions.assertEquals(2, userModelSchema.getColumnSchemas().size());
|
Assertions.assertEquals(2, userModelSchema.getColumnSchemas().size());
|
||||||
Assertions.assertEquals(FieldType.primary_key,
|
Assertions.assertEquals(FieldType.primary_key,
|
||||||
userModelSchema.getColumnByName("user_name").getFiledType());
|
userModelSchema.getColumnByName("user_name").getFiledType());
|
||||||
Assertions.assertEquals(FieldType.dimension,
|
Assertions.assertEquals(FieldType.categorical,
|
||||||
userModelSchema.getColumnByName("department").getFiledType());
|
userModelSchema.getColumnByName("department").getFiledType());
|
||||||
|
|
||||||
ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis");
|
ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis");
|
||||||
@@ -51,7 +51,7 @@ public class SemanticModellerTest extends BaseTest {
|
|||||||
stayTimeModelSchema.getColumnByName("user_name").getFiledType());
|
stayTimeModelSchema.getColumnByName("user_name").getFiledType());
|
||||||
Assertions.assertEquals(FieldType.partition_time,
|
Assertions.assertEquals(FieldType.partition_time,
|
||||||
stayTimeModelSchema.getColumnByName("imp_date").getFiledType());
|
stayTimeModelSchema.getColumnByName("imp_date").getFiledType());
|
||||||
Assertions.assertEquals(FieldType.dimension,
|
Assertions.assertEquals(FieldType.categorical,
|
||||||
stayTimeModelSchema.getColumnByName("page").getFiledType());
|
stayTimeModelSchema.getColumnByName("page").getFiledType());
|
||||||
Assertions.assertEquals(FieldType.measure,
|
Assertions.assertEquals(FieldType.measure,
|
||||||
stayTimeModelSchema.getColumnByName("stay_hours").getFiledType());
|
stayTimeModelSchema.getColumnByName("stay_hours").getFiledType());
|
||||||
@@ -75,9 +75,9 @@ public class SemanticModellerTest extends BaseTest {
|
|||||||
Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size());
|
Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size());
|
||||||
Assertions.assertEquals(FieldType.partition_time,
|
Assertions.assertEquals(FieldType.partition_time,
|
||||||
pvModelSchema.getColumnByName("imp_date").getFiledType());
|
pvModelSchema.getColumnByName("imp_date").getFiledType());
|
||||||
Assertions.assertEquals(FieldType.dimension,
|
Assertions.assertEquals(FieldType.categorical,
|
||||||
pvModelSchema.getColumnByName("user_name").getFiledType());
|
pvModelSchema.getColumnByName("user_name").getFiledType());
|
||||||
Assertions.assertEquals(FieldType.dimension,
|
Assertions.assertEquals(FieldType.categorical,
|
||||||
pvModelSchema.getColumnByName("page").getFiledType());
|
pvModelSchema.getColumnByName("page").getFiledType());
|
||||||
Assertions.assertEquals(FieldType.measure,
|
Assertions.assertEquals(FieldType.measure,
|
||||||
pvModelSchema.getColumnByName("pv").getFiledType());
|
pvModelSchema.getColumnByName("pv").getFiledType());
|
||||||
|
|||||||
Reference in New Issue
Block a user