mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 02:46:56 +00:00
[Fix][launcher]Fix a number of issues related to semantic modeling.
This commit is contained in:
@@ -91,6 +91,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
// mapModes
|
||||
Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
|
||||
List<SemanticParseInfo> candidateParses = Lists.newArrayList();
|
||||
StringBuilder errMsg = new StringBuilder();
|
||||
for (Long datasetId : requestedDatasets) {
|
||||
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
|
||||
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
|
||||
@@ -104,6 +105,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
doParse(queryNLReq, parseResp);
|
||||
}
|
||||
if (parseResp.getSelectedParses().isEmpty()) {
|
||||
errMsg.append(parseResp.getErrorMsg());
|
||||
continue;
|
||||
}
|
||||
// for one dataset select the top 1 parse after sorting
|
||||
@@ -116,6 +118,10 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
SemanticParseInfo.sort(candidateParses);
|
||||
parseContext.getResponse().setSelectedParses(
|
||||
candidateParses.subList(0, Math.min(parserShowCount, candidateParses.size())));
|
||||
if (parseContext.getResponse().getSelectedParses().isEmpty()) {
|
||||
parseContext.getResponse().setState(ParseResp.ParseState.FAILED);
|
||||
parseContext.getResponse().setErrorMsg(errMsg.toString());
|
||||
}
|
||||
}
|
||||
|
||||
// next go with llm-based parsers unless LLM is disabled or use feedback is needed.
|
||||
|
||||
@@ -15,7 +15,7 @@ public class ColumnSchema {
|
||||
|
||||
private FieldType filedType;
|
||||
|
||||
private AggOperatorEnum agg;
|
||||
private AggOperatorEnum agg = AggOperatorEnum.UNKNOWN;
|
||||
|
||||
private String name;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.enums;
|
||||
|
||||
public enum FieldType {
|
||||
primary_key, foreign_key, partition_time, time, dimension, measure;
|
||||
primary_key, foreign_key, partition_time, time, categorical, measure;
|
||||
}
|
||||
|
||||
@@ -9,6 +9,10 @@ import java.util.List;
|
||||
@Data
|
||||
public class ModelBuildReq {
|
||||
|
||||
private String name;
|
||||
|
||||
private String bizName;
|
||||
|
||||
private Long databaseId;
|
||||
|
||||
private Long domainId;
|
||||
|
||||
@@ -211,7 +211,9 @@ public class QueryStructReq extends SemanticQueryReq {
|
||||
SelectItem selectExpressionItem = new SelectItem(function);
|
||||
String alias =
|
||||
StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName;
|
||||
selectExpressionItem.setAlias(new Alias(alias));
|
||||
if (!alias.equals(columnName)) {
|
||||
selectExpressionItem.setAlias(new Alias(alias));
|
||||
}
|
||||
return selectExpressionItem;
|
||||
}
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
|
||||
return connection.getMetaData();
|
||||
}
|
||||
|
||||
protected static FieldType classifyColumnType(String typeName) {
|
||||
public FieldType classifyColumnType(String typeName) {
|
||||
switch (typeName.toUpperCase()) {
|
||||
case "INT":
|
||||
case "INTEGER":
|
||||
@@ -101,7 +101,7 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
|
||||
case "TIMESTAMP":
|
||||
return FieldType.time;
|
||||
default:
|
||||
return FieldType.dimension;
|
||||
return FieldType.categorical;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.core.adaptor.db;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
||||
|
||||
import java.sql.SQLException;
|
||||
@@ -19,4 +20,6 @@ public interface DbAdaptor {
|
||||
|
||||
List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
|
||||
throws SQLException;
|
||||
|
||||
FieldType classifyColumnType(String typeName);
|
||||
}
|
||||
|
||||
@@ -114,7 +114,8 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
|
||||
return dbColumns;
|
||||
}
|
||||
|
||||
protected static FieldType classifyColumnType(String typeName) {
|
||||
@Override
|
||||
public FieldType classifyColumnType(String typeName) {
|
||||
switch (typeName.toUpperCase()) {
|
||||
case "INT":
|
||||
case "INTEGER":
|
||||
@@ -141,7 +142,7 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
|
||||
case "CHARACTER":
|
||||
case "UUID":
|
||||
default:
|
||||
return FieldType.dimension;
|
||||
return FieldType.categorical;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ public class RuleSemanticModeller implements SemanticModeller {
|
||||
|
||||
private ColumnSchema convert(DBColumn dbColumn) {
|
||||
ColumnSchema columnSchema = new ColumnSchema();
|
||||
columnSchema.setName(dbColumn.getComment());
|
||||
columnSchema.setName(dbColumn.getColumnName());
|
||||
columnSchema.setColumnName(dbColumn.getColumnName());
|
||||
columnSchema.setComment(dbColumn.getComment());
|
||||
columnSchema.setDataType(dbColumn.getDataType());
|
||||
|
||||
@@ -79,7 +79,7 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
|
||||
DataSetDO dataSetDO = convert(dataSetReq);
|
||||
dataSetDO.setStatus(StatusEnum.ONLINE.getCode());
|
||||
DataSetResp dataSetResp = convert(dataSetDO);
|
||||
conflictCheck(dataSetResp);
|
||||
// conflictCheck(dataSetResp);
|
||||
save(dataSetDO);
|
||||
dataSetResp.setId(dataSetDO.getId());
|
||||
return dataSetResp;
|
||||
@@ -90,7 +90,7 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
|
||||
dataSetReq.updatedBy(user.getName());
|
||||
DataSetDO dataSetDO = convert(dataSetReq);
|
||||
DataSetResp dataSetResp = convert(dataSetDO);
|
||||
conflictCheck(dataSetResp);
|
||||
// conflictCheck(dataSetResp);
|
||||
updateById(dataSetDO);
|
||||
return dataSetResp;
|
||||
}
|
||||
|
||||
@@ -225,6 +225,9 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
|
||||
if (StringUtils.isNotBlank(modelBuildReq.getSql())) {
|
||||
List<DBColumn> columns =
|
||||
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql());
|
||||
DatabaseResp databaseResp = getDatabase(modelBuildReq.getDatabaseId());
|
||||
DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
|
||||
columns.forEach(c -> c.setFieldType(engineAdaptor.classifyColumnType(c.getDataType())));
|
||||
dbColumnMap.put(modelBuildReq.getSql(), columns);
|
||||
} else {
|
||||
for (String table : modelBuildReq.getTables()) {
|
||||
|
||||
@@ -165,8 +165,8 @@ public class ModelConverter {
|
||||
public static ModelReq convert(ModelSchema modelSchema, ModelBuildReq modelBuildReq,
|
||||
String tableName) {
|
||||
ModelReq modelReq = new ModelReq();
|
||||
modelReq.setName(modelSchema.getName());
|
||||
modelReq.setBizName(modelSchema.getBizName());
|
||||
modelReq.setName(modelBuildReq.getName());
|
||||
modelReq.setBizName(modelBuildReq.getBizName());
|
||||
modelReq.setDatabaseId(modelBuildReq.getDatabaseId());
|
||||
modelReq.setDomainId(modelBuildReq.getDomainId());
|
||||
ModelDetail modelDetail = new ModelDetail();
|
||||
@@ -198,10 +198,12 @@ public class ModelConverter {
|
||||
}
|
||||
|
||||
private static IdentifyType getIdentifyType(FieldType fieldType) {
|
||||
if (FieldType.foreign_key.equals(fieldType) || FieldType.primary_key.equals(fieldType)) {
|
||||
if (FieldType.primary_key.equals(fieldType)) {
|
||||
return IdentifyType.primary;
|
||||
} else {
|
||||
} else if (FieldType.foreign_key.equals(fieldType)) {
|
||||
return IdentifyType.foreign;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ public class SemanticModellerTest extends BaseTest {
|
||||
Assertions.assertEquals(2, userModelSchema.getColumnSchemas().size());
|
||||
Assertions.assertEquals(FieldType.primary_key,
|
||||
userModelSchema.getColumnByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
Assertions.assertEquals(FieldType.categorical,
|
||||
userModelSchema.getColumnByName("department").getFiledType());
|
||||
|
||||
ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis");
|
||||
@@ -51,7 +51,7 @@ public class SemanticModellerTest extends BaseTest {
|
||||
stayTimeModelSchema.getColumnByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.partition_time,
|
||||
stayTimeModelSchema.getColumnByName("imp_date").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
Assertions.assertEquals(FieldType.categorical,
|
||||
stayTimeModelSchema.getColumnByName("page").getFiledType());
|
||||
Assertions.assertEquals(FieldType.measure,
|
||||
stayTimeModelSchema.getColumnByName("stay_hours").getFiledType());
|
||||
@@ -75,9 +75,9 @@ public class SemanticModellerTest extends BaseTest {
|
||||
Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size());
|
||||
Assertions.assertEquals(FieldType.partition_time,
|
||||
pvModelSchema.getColumnByName("imp_date").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
Assertions.assertEquals(FieldType.categorical,
|
||||
pvModelSchema.getColumnByName("user_name").getFiledType());
|
||||
Assertions.assertEquals(FieldType.dimension,
|
||||
Assertions.assertEquals(FieldType.categorical,
|
||||
pvModelSchema.getColumnByName("page").getFiledType());
|
||||
Assertions.assertEquals(FieldType.measure,
|
||||
pvModelSchema.getColumnByName("pv").getFiledType());
|
||||
|
||||
Reference in New Issue
Block a user