[Fix][launcher]Fix a number of issues related to semantic modeling.

This commit is contained in:
jerryjzhang
2024-12-02 21:38:15 +08:00
parent cf79ac9ece
commit 0ce79cbfc0
13 changed files with 39 additions and 18 deletions

View File

@@ -91,6 +91,7 @@ public class NL2SQLParser implements ChatQueryParser {
// mapModes
Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
List<SemanticParseInfo> candidateParses = Lists.newArrayList();
StringBuilder errMsg = new StringBuilder();
for (Long datasetId : requestedDatasets) {
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
@@ -104,6 +105,7 @@ public class NL2SQLParser implements ChatQueryParser {
doParse(queryNLReq, parseResp);
}
if (parseResp.getSelectedParses().isEmpty()) {
errMsg.append(parseResp.getErrorMsg());
continue;
}
// for one dataset select the top 1 parse after sorting
@@ -116,6 +118,10 @@ public class NL2SQLParser implements ChatQueryParser {
SemanticParseInfo.sort(candidateParses);
parseContext.getResponse().setSelectedParses(
candidateParses.subList(0, Math.min(parserShowCount, candidateParses.size())));
if (parseContext.getResponse().getSelectedParses().isEmpty()) {
parseContext.getResponse().setState(ParseResp.ParseState.FAILED);
parseContext.getResponse().setErrorMsg(errMsg.toString());
}
}
// next go with llm-based parsers unless LLM is disabled or use feedback is needed.

View File

@@ -15,7 +15,7 @@ public class ColumnSchema {
private FieldType filedType;
private AggOperatorEnum agg;
private AggOperatorEnum agg = AggOperatorEnum.UNKNOWN;
private String name;

View File

@@ -1,5 +1,5 @@
package com.tencent.supersonic.headless.api.pojo.enums;
public enum FieldType {
primary_key, foreign_key, partition_time, time, dimension, measure;
primary_key, foreign_key, partition_time, time, categorical, measure;
}

View File

@@ -9,6 +9,10 @@ import java.util.List;
@Data
public class ModelBuildReq {
private String name;
private String bizName;
private Long databaseId;
private Long domainId;

View File

@@ -211,7 +211,9 @@ public class QueryStructReq extends SemanticQueryReq {
SelectItem selectExpressionItem = new SelectItem(function);
String alias =
StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName;
selectExpressionItem.setAlias(new Alias(alias));
if (!alias.equals(columnName)) {
selectExpressionItem.setAlias(new Alias(alias));
}
return selectExpressionItem;
}

View File

@@ -84,7 +84,7 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
return connection.getMetaData();
}
protected static FieldType classifyColumnType(String typeName) {
public FieldType classifyColumnType(String typeName) {
switch (typeName.toUpperCase()) {
case "INT":
case "INTEGER":
@@ -101,7 +101,7 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
case "TIMESTAMP":
return FieldType.time;
default:
return FieldType.dimension;
return FieldType.categorical;
}
}

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.core.adaptor.db;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import java.sql.SQLException;
@@ -19,4 +20,6 @@ public interface DbAdaptor {
List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
throws SQLException;
FieldType classifyColumnType(String typeName);
}

View File

@@ -114,7 +114,8 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
return dbColumns;
}
protected static FieldType classifyColumnType(String typeName) {
@Override
public FieldType classifyColumnType(String typeName) {
switch (typeName.toUpperCase()) {
case "INT":
case "INTEGER":
@@ -141,7 +142,7 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
case "CHARACTER":
case "UUID":
default:
return FieldType.dimension;
return FieldType.categorical;
}
}

View File

@@ -23,7 +23,7 @@ public class RuleSemanticModeller implements SemanticModeller {
private ColumnSchema convert(DBColumn dbColumn) {
ColumnSchema columnSchema = new ColumnSchema();
columnSchema.setName(dbColumn.getComment());
columnSchema.setName(dbColumn.getColumnName());
columnSchema.setColumnName(dbColumn.getColumnName());
columnSchema.setComment(dbColumn.getComment());
columnSchema.setDataType(dbColumn.getDataType());

View File

@@ -79,7 +79,7 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
DataSetDO dataSetDO = convert(dataSetReq);
dataSetDO.setStatus(StatusEnum.ONLINE.getCode());
DataSetResp dataSetResp = convert(dataSetDO);
conflictCheck(dataSetResp);
// conflictCheck(dataSetResp);
save(dataSetDO);
dataSetResp.setId(dataSetDO.getId());
return dataSetResp;
@@ -90,7 +90,7 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
dataSetReq.updatedBy(user.getName());
DataSetDO dataSetDO = convert(dataSetReq);
DataSetResp dataSetResp = convert(dataSetDO);
conflictCheck(dataSetResp);
// conflictCheck(dataSetResp);
updateById(dataSetDO);
return dataSetResp;
}

View File

@@ -225,6 +225,9 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
if (StringUtils.isNotBlank(modelBuildReq.getSql())) {
List<DBColumn> columns =
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql());
DatabaseResp databaseResp = getDatabase(modelBuildReq.getDatabaseId());
DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
columns.forEach(c -> c.setFieldType(engineAdaptor.classifyColumnType(c.getDataType())));
dbColumnMap.put(modelBuildReq.getSql(), columns);
} else {
for (String table : modelBuildReq.getTables()) {

View File

@@ -165,8 +165,8 @@ public class ModelConverter {
public static ModelReq convert(ModelSchema modelSchema, ModelBuildReq modelBuildReq,
String tableName) {
ModelReq modelReq = new ModelReq();
modelReq.setName(modelSchema.getName());
modelReq.setBizName(modelSchema.getBizName());
modelReq.setName(modelBuildReq.getName());
modelReq.setBizName(modelBuildReq.getBizName());
modelReq.setDatabaseId(modelBuildReq.getDatabaseId());
modelReq.setDomainId(modelBuildReq.getDomainId());
ModelDetail modelDetail = new ModelDetail();
@@ -198,10 +198,12 @@ public class ModelConverter {
}
private static IdentifyType getIdentifyType(FieldType fieldType) {
if (FieldType.foreign_key.equals(fieldType) || FieldType.primary_key.equals(fieldType)) {
if (FieldType.primary_key.equals(fieldType)) {
return IdentifyType.primary;
} else {
} else if (FieldType.foreign_key.equals(fieldType)) {
return IdentifyType.foreign;
} else {
return null;
}
}

View File

@@ -42,7 +42,7 @@ public class SemanticModellerTest extends BaseTest {
Assertions.assertEquals(2, userModelSchema.getColumnSchemas().size());
Assertions.assertEquals(FieldType.primary_key,
userModelSchema.getColumnByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.dimension,
Assertions.assertEquals(FieldType.categorical,
userModelSchema.getColumnByName("department").getFiledType());
ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis");
@@ -51,7 +51,7 @@ public class SemanticModellerTest extends BaseTest {
stayTimeModelSchema.getColumnByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.partition_time,
stayTimeModelSchema.getColumnByName("imp_date").getFiledType());
Assertions.assertEquals(FieldType.dimension,
Assertions.assertEquals(FieldType.categorical,
stayTimeModelSchema.getColumnByName("page").getFiledType());
Assertions.assertEquals(FieldType.measure,
stayTimeModelSchema.getColumnByName("stay_hours").getFiledType());
@@ -75,9 +75,9 @@ public class SemanticModellerTest extends BaseTest {
Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size());
Assertions.assertEquals(FieldType.partition_time,
pvModelSchema.getColumnByName("imp_date").getFiledType());
Assertions.assertEquals(FieldType.dimension,
Assertions.assertEquals(FieldType.categorical,
pvModelSchema.getColumnByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.dimension,
Assertions.assertEquals(FieldType.categorical,
pvModelSchema.getColumnByName("page").getFiledType());
Assertions.assertEquals(FieldType.measure,
pvModelSchema.getColumnByName("pv").getFiledType());