[Fix][launcher]Fix a number of issues related to semantic modeling.

This commit is contained in:
jerryjzhang
2024-12-02 21:38:15 +08:00
parent cf79ac9ece
commit 0ce79cbfc0
13 changed files with 39 additions and 18 deletions

View File

@@ -91,6 +91,7 @@ public class NL2SQLParser implements ChatQueryParser {
// mapModes // mapModes
Set<Long> requestedDatasets = queryNLReq.getDataSetIds(); Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
List<SemanticParseInfo> candidateParses = Lists.newArrayList(); List<SemanticParseInfo> candidateParses = Lists.newArrayList();
StringBuilder errMsg = new StringBuilder();
for (Long datasetId : requestedDatasets) { for (Long datasetId : requestedDatasets) {
queryNLReq.setDataSetIds(Collections.singleton(datasetId)); queryNLReq.setDataSetIds(Collections.singleton(datasetId));
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId()); ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
@@ -104,6 +105,7 @@ public class NL2SQLParser implements ChatQueryParser {
doParse(queryNLReq, parseResp); doParse(queryNLReq, parseResp);
} }
if (parseResp.getSelectedParses().isEmpty()) { if (parseResp.getSelectedParses().isEmpty()) {
errMsg.append(parseResp.getErrorMsg());
continue; continue;
} }
// for one dataset select the top 1 parse after sorting // for one dataset select the top 1 parse after sorting
@@ -116,6 +118,10 @@ public class NL2SQLParser implements ChatQueryParser {
SemanticParseInfo.sort(candidateParses); SemanticParseInfo.sort(candidateParses);
parseContext.getResponse().setSelectedParses( parseContext.getResponse().setSelectedParses(
candidateParses.subList(0, Math.min(parserShowCount, candidateParses.size()))); candidateParses.subList(0, Math.min(parserShowCount, candidateParses.size())));
if (parseContext.getResponse().getSelectedParses().isEmpty()) {
parseContext.getResponse().setState(ParseResp.ParseState.FAILED);
parseContext.getResponse().setErrorMsg(errMsg.toString());
}
} }
// next go with llm-based parsers unless LLM is disabled or use feedback is needed. // next go with llm-based parsers unless LLM is disabled or use feedback is needed.

View File

@@ -15,7 +15,7 @@ public class ColumnSchema {
private FieldType filedType; private FieldType filedType;
private AggOperatorEnum agg; private AggOperatorEnum agg = AggOperatorEnum.UNKNOWN;
private String name; private String name;

View File

@@ -1,5 +1,5 @@
package com.tencent.supersonic.headless.api.pojo.enums; package com.tencent.supersonic.headless.api.pojo.enums;
public enum FieldType { public enum FieldType {
primary_key, foreign_key, partition_time, time, dimension, measure; primary_key, foreign_key, partition_time, time, categorical, measure;
} }

View File

@@ -9,6 +9,10 @@ import java.util.List;
@Data @Data
public class ModelBuildReq { public class ModelBuildReq {
private String name;
private String bizName;
private Long databaseId; private Long databaseId;
private Long domainId; private Long domainId;

View File

@@ -211,7 +211,9 @@ public class QueryStructReq extends SemanticQueryReq {
SelectItem selectExpressionItem = new SelectItem(function); SelectItem selectExpressionItem = new SelectItem(function);
String alias = String alias =
StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName; StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() : columnName;
selectExpressionItem.setAlias(new Alias(alias)); if (!alias.equals(columnName)) {
selectExpressionItem.setAlias(new Alias(alias));
}
return selectExpressionItem; return selectExpressionItem;
} }

View File

@@ -84,7 +84,7 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
return connection.getMetaData(); return connection.getMetaData();
} }
protected static FieldType classifyColumnType(String typeName) { public FieldType classifyColumnType(String typeName) {
switch (typeName.toUpperCase()) { switch (typeName.toUpperCase()) {
case "INT": case "INT":
case "INTEGER": case "INTEGER":
@@ -101,7 +101,7 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
case "TIMESTAMP": case "TIMESTAMP":
return FieldType.time; return FieldType.time;
default: default:
return FieldType.dimension; return FieldType.categorical;
} }
} }

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.core.adaptor.db; package com.tencent.supersonic.headless.core.adaptor.db;
import com.tencent.supersonic.headless.api.pojo.DBColumn; import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo; import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import java.sql.SQLException; import java.sql.SQLException;
@@ -19,4 +20,6 @@ public interface DbAdaptor {
List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName) List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
throws SQLException; throws SQLException;
FieldType classifyColumnType(String typeName);
} }

View File

@@ -114,7 +114,8 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
return dbColumns; return dbColumns;
} }
protected static FieldType classifyColumnType(String typeName) { @Override
public FieldType classifyColumnType(String typeName) {
switch (typeName.toUpperCase()) { switch (typeName.toUpperCase()) {
case "INT": case "INT":
case "INTEGER": case "INTEGER":
@@ -141,7 +142,7 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
case "CHARACTER": case "CHARACTER":
case "UUID": case "UUID":
default: default:
return FieldType.dimension; return FieldType.categorical;
} }
} }

View File

@@ -23,7 +23,7 @@ public class RuleSemanticModeller implements SemanticModeller {
private ColumnSchema convert(DBColumn dbColumn) { private ColumnSchema convert(DBColumn dbColumn) {
ColumnSchema columnSchema = new ColumnSchema(); ColumnSchema columnSchema = new ColumnSchema();
columnSchema.setName(dbColumn.getComment()); columnSchema.setName(dbColumn.getColumnName());
columnSchema.setColumnName(dbColumn.getColumnName()); columnSchema.setColumnName(dbColumn.getColumnName());
columnSchema.setComment(dbColumn.getComment()); columnSchema.setComment(dbColumn.getComment());
columnSchema.setDataType(dbColumn.getDataType()); columnSchema.setDataType(dbColumn.getDataType());

View File

@@ -79,7 +79,7 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
DataSetDO dataSetDO = convert(dataSetReq); DataSetDO dataSetDO = convert(dataSetReq);
dataSetDO.setStatus(StatusEnum.ONLINE.getCode()); dataSetDO.setStatus(StatusEnum.ONLINE.getCode());
DataSetResp dataSetResp = convert(dataSetDO); DataSetResp dataSetResp = convert(dataSetDO);
conflictCheck(dataSetResp); // conflictCheck(dataSetResp);
save(dataSetDO); save(dataSetDO);
dataSetResp.setId(dataSetDO.getId()); dataSetResp.setId(dataSetDO.getId());
return dataSetResp; return dataSetResp;
@@ -90,7 +90,7 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
dataSetReq.updatedBy(user.getName()); dataSetReq.updatedBy(user.getName());
DataSetDO dataSetDO = convert(dataSetReq); DataSetDO dataSetDO = convert(dataSetReq);
DataSetResp dataSetResp = convert(dataSetDO); DataSetResp dataSetResp = convert(dataSetDO);
conflictCheck(dataSetResp); // conflictCheck(dataSetResp);
updateById(dataSetDO); updateById(dataSetDO);
return dataSetResp; return dataSetResp;
} }

View File

@@ -225,6 +225,9 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
if (StringUtils.isNotBlank(modelBuildReq.getSql())) { if (StringUtils.isNotBlank(modelBuildReq.getSql())) {
List<DBColumn> columns = List<DBColumn> columns =
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql()); getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getSql());
DatabaseResp databaseResp = getDatabase(modelBuildReq.getDatabaseId());
DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
columns.forEach(c -> c.setFieldType(engineAdaptor.classifyColumnType(c.getDataType())));
dbColumnMap.put(modelBuildReq.getSql(), columns); dbColumnMap.put(modelBuildReq.getSql(), columns);
} else { } else {
for (String table : modelBuildReq.getTables()) { for (String table : modelBuildReq.getTables()) {

View File

@@ -165,8 +165,8 @@ public class ModelConverter {
public static ModelReq convert(ModelSchema modelSchema, ModelBuildReq modelBuildReq, public static ModelReq convert(ModelSchema modelSchema, ModelBuildReq modelBuildReq,
String tableName) { String tableName) {
ModelReq modelReq = new ModelReq(); ModelReq modelReq = new ModelReq();
modelReq.setName(modelSchema.getName()); modelReq.setName(modelBuildReq.getName());
modelReq.setBizName(modelSchema.getBizName()); modelReq.setBizName(modelBuildReq.getBizName());
modelReq.setDatabaseId(modelBuildReq.getDatabaseId()); modelReq.setDatabaseId(modelBuildReq.getDatabaseId());
modelReq.setDomainId(modelBuildReq.getDomainId()); modelReq.setDomainId(modelBuildReq.getDomainId());
ModelDetail modelDetail = new ModelDetail(); ModelDetail modelDetail = new ModelDetail();
@@ -198,10 +198,12 @@ public class ModelConverter {
} }
private static IdentifyType getIdentifyType(FieldType fieldType) { private static IdentifyType getIdentifyType(FieldType fieldType) {
if (FieldType.foreign_key.equals(fieldType) || FieldType.primary_key.equals(fieldType)) { if (FieldType.primary_key.equals(fieldType)) {
return IdentifyType.primary; return IdentifyType.primary;
} else { } else if (FieldType.foreign_key.equals(fieldType)) {
return IdentifyType.foreign; return IdentifyType.foreign;
} else {
return null;
} }
} }

View File

@@ -42,7 +42,7 @@ public class SemanticModellerTest extends BaseTest {
Assertions.assertEquals(2, userModelSchema.getColumnSchemas().size()); Assertions.assertEquals(2, userModelSchema.getColumnSchemas().size());
Assertions.assertEquals(FieldType.primary_key, Assertions.assertEquals(FieldType.primary_key,
userModelSchema.getColumnByName("user_name").getFiledType()); userModelSchema.getColumnByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.dimension, Assertions.assertEquals(FieldType.categorical,
userModelSchema.getColumnByName("department").getFiledType()); userModelSchema.getColumnByName("department").getFiledType());
ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis"); ModelSchema stayTimeModelSchema = modelSchemaMap.get("s2_stay_time_statis");
@@ -51,7 +51,7 @@ public class SemanticModellerTest extends BaseTest {
stayTimeModelSchema.getColumnByName("user_name").getFiledType()); stayTimeModelSchema.getColumnByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.partition_time, Assertions.assertEquals(FieldType.partition_time,
stayTimeModelSchema.getColumnByName("imp_date").getFiledType()); stayTimeModelSchema.getColumnByName("imp_date").getFiledType());
Assertions.assertEquals(FieldType.dimension, Assertions.assertEquals(FieldType.categorical,
stayTimeModelSchema.getColumnByName("page").getFiledType()); stayTimeModelSchema.getColumnByName("page").getFiledType());
Assertions.assertEquals(FieldType.measure, Assertions.assertEquals(FieldType.measure,
stayTimeModelSchema.getColumnByName("stay_hours").getFiledType()); stayTimeModelSchema.getColumnByName("stay_hours").getFiledType());
@@ -75,9 +75,9 @@ public class SemanticModellerTest extends BaseTest {
Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size()); Assertions.assertEquals(5, pvModelSchema.getColumnSchemas().size());
Assertions.assertEquals(FieldType.partition_time, Assertions.assertEquals(FieldType.partition_time,
pvModelSchema.getColumnByName("imp_date").getFiledType()); pvModelSchema.getColumnByName("imp_date").getFiledType());
Assertions.assertEquals(FieldType.dimension, Assertions.assertEquals(FieldType.categorical,
pvModelSchema.getColumnByName("user_name").getFiledType()); pvModelSchema.getColumnByName("user_name").getFiledType());
Assertions.assertEquals(FieldType.dimension, Assertions.assertEquals(FieldType.categorical,
pvModelSchema.getColumnByName("page").getFiledType()); pvModelSchema.getColumnByName("page").getFiledType());
Assertions.assertEquals(FieldType.measure, Assertions.assertEquals(FieldType.measure,
pvModelSchema.getColumnByName("pv").getFiledType()); pvModelSchema.getColumnByName("pv").getFiledType());