feat:Support kyuubi presto trino (#2109)

This commit is contained in:
zyclove
2025-02-26 17:33:14 +08:00
committed by GitHub
parent 11ff99cdbe
commit 5e3bafb953
31 changed files with 501 additions and 101 deletions

View File

@@ -10,7 +10,10 @@ public enum EngineType {
OTHER(7, "OTHER"), OTHER(7, "OTHER"),
DUCKDB(8, "DUCKDB"), DUCKDB(8, "DUCKDB"),
HANADB(9, "HANADB"), HANADB(9, "HANADB"),
STARROCKS(10, "STARROCKS"),; STARROCKS(10, "STARROCKS"),
KYUUBI(11, "KYUUBI"),
PRESTO(12, "PRESTO"),
TRINO(13, "TRINO"),;
private Integer code; private Integer code;

View File

@@ -11,6 +11,8 @@ import java.util.List;
@NoArgsConstructor @NoArgsConstructor
public class DbSchema { public class DbSchema {
private String catalog;
private String db; private String db;
private String table; private String table;

View File

@@ -8,7 +8,9 @@ import java.util.Set;
public enum DataType { public enum DataType {
MYSQL("mysql", "mysql", "com.mysql.cj.jdbc.Driver", "`", "`", "'", "'"), MYSQL("mysql", "mysql", "com.mysql.cj.jdbc.Driver", "`", "`", "'", "'"),
HIVE2("hive2", "hive", "org.apache.hive.jdbc.HiveDriver", "`", "`", "`", "`"), HIVE2("hive2", "hive", "org.apache.kyuubi.jdbc.KyuubiHiveDriver", "`", "`", "`", "`"),
KYUUBI("kyuubi", "kyuubi", "org.apache.kyuubi.jdbc.KyuubiHiveDriver", "`", "`", "`", "`"),
ORACLE("oracle", "oracle", "oracle.jdbc.driver.OracleDriver", "\"", "\"", "\"", "\""), ORACLE("oracle", "oracle", "oracle.jdbc.driver.OracleDriver", "\"", "\"", "\"", "\""),
@@ -27,6 +29,8 @@ public enum DataType {
PRESTO("presto", "presto", "com.facebook.presto.jdbc.PrestoDriver", "\"", "\"", "\"", "\""), PRESTO("presto", "presto", "com.facebook.presto.jdbc.PrestoDriver", "\"", "\"", "\"", "\""),
TRINO("trino", "trino", "io.trino.jdbc.TrinoDriver", "\"", "\"", "\"", "\""),
MOONBOX("moonbox", "moonbox", "moonbox.jdbc.MbDriver", "`", "`", "`", "`"), MOONBOX("moonbox", "moonbox", "moonbox.jdbc.MbDriver", "`", "`", "`", "`"),
CASSANDRA("cassandra", "cassandra", "com.github.adejanovski.cassandra.jdbc.CassandraDriver", "", CASSANDRA("cassandra", "cassandra", "com.github.adejanovski.cassandra.jdbc.CassandraDriver", "",
@@ -46,6 +50,7 @@ public enum DataType {
TDENGINE("TAOS", "TAOS", "com.taosdata.jdbc.TSDBDriver", "'", "'", "\"", "\""), TDENGINE("TAOS", "TAOS", "com.taosdata.jdbc.TSDBDriver", "'", "'", "\"", "\""),
POSTGRESQL("postgresql", "postgresql", "org.postgresql.Driver", "'", "'", "\"", "\""), POSTGRESQL("postgresql", "postgresql", "org.postgresql.Driver", "'", "'", "\"", "\""),
DUCKDB("duckdb", "duckdb", "org.duckdb.DuckDBDriver", "'", "'", "\"", "\""); DUCKDB("duckdb", "duckdb", "org.duckdb.DuckDBDriver", "'", "'", "\"", "\"");
private String feature; private String feature;

View File

@@ -19,6 +19,8 @@ public class ModelBuildReq {
private String sql; private String sql;
private String catalog;
private String db; private String db;
private List<String> tables; private List<String> tables;

View File

@@ -72,7 +72,8 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
// 1. Base detection // 1. Base detection
List<EmbeddingResult> baseResults = super.detect(chatQueryContext, terms, detectDataSetIds); List<EmbeddingResult> baseResults = super.detect(chatQueryContext, terms, detectDataSetIds);
boolean useLLM = Boolean.parseBoolean(mapperConfig.getParameterValue(EMBEDDING_MAPPER_USE_LLM)); boolean useLLM =
Boolean.parseBoolean(mapperConfig.getParameterValue(EMBEDDING_MAPPER_USE_LLM));
// 2. LLM enhanced detection // 2. LLM enhanced detection
if (useLLM) { if (useLLM) {
@@ -115,7 +116,8 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
* Extract valid word segments by filtering out unwanted word natures * Extract valid word segments by filtering out unwanted word natures
*/ */
private Set<String> extractValidSegments(String text) { private Set<String> extractValidSegments(String text) {
List<String> natureList = Arrays.asList(StringUtils.split(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE ), ",")); List<String> natureList = Arrays.asList(StringUtils.split(
mapperConfig.getParameterValue(EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE), ","));
return HanlpHelper.getSegment().seg(text).stream() return HanlpHelper.getSegment().seg(text).stream()
.filter(t -> natureList.stream().noneMatch(nature -> t.nature.startsWith(nature))) .filter(t -> natureList.stream().noneMatch(nature -> t.nature.startsWith(nature)))
.map(Term::getWord).collect(Collectors.toSet()); .map(Term::getWord).collect(Collectors.toSet());

View File

@@ -61,7 +61,8 @@ public class MapFilter {
List<SchemaElementMatch> value = entry.getValue(); List<SchemaElementMatch> value = entry.getValue();
if (!CollectionUtils.isEmpty(value)) { if (!CollectionUtils.isEmpty(value)) {
value.removeIf(schemaElementMatch -> StringUtils value.removeIf(schemaElementMatch -> StringUtils
.length(schemaElementMatch.getDetectWord()) <= 1 && !schemaElementMatch.isLlmMatched()); .length(schemaElementMatch.getDetectWord()) <= 1
&& !schemaElementMatch.isLlmMatched());
} }
} }
} }
@@ -80,7 +81,7 @@ public class MapFilter {
} }
public static void filterByQueryDataType(ChatQueryContext chatQueryContext, public static void filterByQueryDataType(ChatQueryContext chatQueryContext,
Predicate<SchemaElement> needRemovePredicate) { Predicate<SchemaElement> needRemovePredicate) {
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
chatQueryContext.getMapInfo().getDataSetElementMatches(); chatQueryContext.getMapInfo().getDataSetElementMatches();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) { for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {

View File

@@ -63,6 +63,6 @@ public class MapperConfig extends ParameterConfig {
"embedding的结果再通过一次LLM来筛选这时候忽略各个向量阀值", "bool", "Mapper相关配置"); "embedding的结果再通过一次LLM来筛选这时候忽略各个向量阀值", "bool", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE = public static final Parameter EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE =
new Parameter("s2.mapper.embedding.allowed-segment-nature", "['v', 'd', 'a']", "使用LLM召回二次处理时对问题分词词性的控制", new Parameter("s2.mapper.embedding.allowed-segment-nature", "['v', 'd', 'a']",
"分词后允许的词性才会进行向量召回", "list", "Mapper相关配置"); "使用LLM召回二次处理时对问题分词词性的控制", "分词后允许的词性才会进行向量召回", "list", "Mapper相关配置");
} }

View File

@@ -121,6 +121,18 @@
<artifactId>DmJdbcDriver18</artifactId> <artifactId>DmJdbcDriver18</artifactId>
<version>8.1.2.192</version> <version>8.1.2.192</version>
</dependency> </dependency>
<dependency>
<groupId>org.apache.kyuubi</groupId>
<artifactId>kyuubi-hive-jdbc</artifactId>
</dependency>
<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-jdbc</artifactId>
</dependency>
<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-jdbc</artifactId>
</dependency>
</dependencies> </dependencies>

View File

@@ -5,23 +5,27 @@ import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType; import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo; import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.sql.Connection; import java.sql.*;
import java.sql.DatabaseMetaData;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Properties;
@Slf4j @Slf4j
public abstract class BaseDbAdaptor implements DbAdaptor { public abstract class BaseDbAdaptor implements DbAdaptor {
@Override @Override
public List<String> getCatalogs(ConnectInfo connectInfo) throws SQLException { public List<String> getCatalogs(ConnectInfo connectInfo) throws SQLException {
// Apart from supporting multiple catalog types of data sources, other types will return an List<String> catalogs = Lists.newArrayList();
// empty set by default. try (Connection con = getConnection(connectInfo);
return List.of(); Statement st = con.createStatement();
ResultSet rs = st.executeQuery("SHOW CATALOGS")) {
while (rs.next()) {
catalogs.add(rs.getString(1));
}
}
return catalogs;
} }
public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException { public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException {
@@ -32,38 +36,49 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
protected List<String> getDBs(ConnectInfo connectionInfo) throws SQLException { protected List<String> getDBs(ConnectInfo connectionInfo) throws SQLException {
List<String> dbs = Lists.newArrayList(); List<String> dbs = Lists.newArrayList();
DatabaseMetaData metaData = getDatabaseMetaData(connectionInfo);
try { try {
ResultSet schemaSet = metaData.getSchemas(); try (ResultSet schemaSet = getDatabaseMetaData(connectionInfo).getSchemas()) {
while (schemaSet.next()) { while (schemaSet.next()) {
String db = schemaSet.getString("TABLE_SCHEM"); String db = schemaSet.getString("TABLE_SCHEM");
dbs.add(db); dbs.add(db);
}
} }
} catch (Exception e) { } catch (Exception e) {
log.info("get meta schemas failed, try to get catalogs"); log.warn("get meta schemas failed", e);
log.warn("get meta schemas failed, try to get catalogs");
} }
try { try {
ResultSet catalogSet = metaData.getCatalogs(); try (ResultSet catalogSet = getDatabaseMetaData(connectionInfo).getCatalogs()) {
while (catalogSet.next()) { while (catalogSet.next()) {
String db = catalogSet.getString("TABLE_CAT"); String db = catalogSet.getString("TABLE_CAT");
dbs.add(db); dbs.add(db);
}
} }
} catch (Exception e) { } catch (Exception e) {
log.info("get meta catalogs failed, try to get schemas"); log.warn("get meta catalogs failed", e);
log.warn("get meta catalogs failed, try to get schemas");
} }
return dbs; return dbs;
} }
public List<String> getTables(ConnectInfo connectionInfo, String schemaName) @Override
public List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName)
throws SQLException {
// Except for special types implemented separately, the generic logic catalog does not take
// effect.
return getTables(connectInfo, schemaName);
}
protected List<String> getTables(ConnectInfo connectionInfo, String schemaName)
throws SQLException { throws SQLException {
List<String> tablesAndViews = new ArrayList<>(); List<String> tablesAndViews = new ArrayList<>();
DatabaseMetaData metaData = getDatabaseMetaData(connectionInfo);
try { try {
ResultSet resultSet = getResultSet(schemaName, metaData); try(ResultSet resultSet = getResultSet(schemaName, getDatabaseMetaData(connectionInfo))) {
while (resultSet.next()) { while (resultSet.next()) {
String name = resultSet.getString("TABLE_NAME"); String name = resultSet.getString("TABLE_NAME");
tablesAndViews.add(name); tablesAndViews.add(name);
}
} }
} catch (SQLException e) { } catch (SQLException e) {
log.error("Failed to get tables and views", e); log.error("Failed to get tables and views", e);
@@ -76,27 +91,34 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
return metaData.getTables(schemaName, schemaName, null, new String[] {"TABLE", "VIEW"}); return metaData.getTables(schemaName, schemaName, null, new String[] {"TABLE", "VIEW"});
} }
public List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
public List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException { throws SQLException {
List<DBColumn> dbColumns = Lists.newArrayList(); List<DBColumn> dbColumns = new ArrayList<>();
DatabaseMetaData metaData = getDatabaseMetaData(connectInfo); // 确保连接会自动关闭
ResultSet columns = metaData.getColumns(schemaName, schemaName, tableName, null); try (ResultSet columns = getDatabaseMetaData(connectInfo).getColumns(catalog, schemaName, tableName, null)) {
while (columns.next()) { while (columns.next()) {
String columnName = columns.getString("COLUMN_NAME"); String columnName = columns.getString("COLUMN_NAME");
String dataType = columns.getString("TYPE_NAME"); String dataType = columns.getString("TYPE_NAME");
String remarks = columns.getString("REMARKS"); String remarks = columns.getString("REMARKS");
FieldType fieldType = classifyColumnType(dataType); FieldType fieldType = classifyColumnType(dataType);
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType)); dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
}
} }
return dbColumns; return dbColumns;
} }
protected DatabaseMetaData getDatabaseMetaData(ConnectInfo connectionInfo) throws SQLException { protected DatabaseMetaData getDatabaseMetaData(ConnectInfo connectionInfo) throws SQLException {
Connection connection = DriverManager.getConnection(connectionInfo.getUrl(), Connection connection = getConnection(connectionInfo);
connectionInfo.getUserName(), connectionInfo.getPassword());
return connection.getMetaData(); return connection.getMetaData();
} }
public Connection getConnection(ConnectInfo connectionInfo) throws SQLException {
final Properties properties = getProperties(connectionInfo);
return DriverManager.getConnection(connectionInfo.getUrl(), properties);
}
public FieldType classifyColumnType(String typeName) { public FieldType classifyColumnType(String typeName) {
switch (typeName.toUpperCase()) { switch (typeName.toUpperCase()) {
case "INT": case "INT":
@@ -118,4 +140,24 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
} }
} }
public Properties getProperties(ConnectInfo connectionInfo) {
final Properties properties = new Properties();
String url = connectionInfo.getUrl().toLowerCase();
// 设置通用属性
properties.setProperty("user", connectionInfo.getUserName());
// 针对 Presto 和 Trino ssl=false 的情况,不需要设置密码
if (url.startsWith("jdbc:presto") || url.startsWith("jdbc:trino")) {
// 检查是否需要处理 SSL
if (!url.contains("ssl=false")) {
properties.setProperty("password", connectionInfo.getPassword());
}
} else {
// 针对其他数据库类型
properties.setProperty("password", connectionInfo.getPassword());
}
return properties;
}
} }

View File

@@ -18,9 +18,10 @@ public interface DbAdaptor {
List<String> getDBs(ConnectInfo connectInfo, String catalog) throws SQLException; List<String> getDBs(ConnectInfo connectInfo, String catalog) throws SQLException;
List<String> getTables(ConnectInfo connectInfo, String schemaName) throws SQLException; List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName)
throws SQLException;
List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName) List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException; throws SQLException;
FieldType classifyColumnType(String typeName); FieldType classifyColumnType(String typeName);

View File

@@ -19,6 +19,9 @@ public class DbAdaptorFactory {
dbAdaptorMap.put(EngineType.DUCKDB.getName(), new DuckdbAdaptor()); dbAdaptorMap.put(EngineType.DUCKDB.getName(), new DuckdbAdaptor());
dbAdaptorMap.put(EngineType.HANADB.getName(), new HanadbAdaptor()); dbAdaptorMap.put(EngineType.HANADB.getName(), new HanadbAdaptor());
dbAdaptorMap.put(EngineType.STARROCKS.getName(), new StarrocksAdaptor()); dbAdaptorMap.put(EngineType.STARROCKS.getName(), new StarrocksAdaptor());
dbAdaptorMap.put(EngineType.KYUUBI.getName(), new KyuubiAdaptor());
dbAdaptorMap.put(EngineType.PRESTO.getName(), new PrestoAdaptor());
dbAdaptorMap.put(EngineType.TRINO.getName(), new TrinoAdaptor());
} }
public static DbAdaptor getEngineAdaptor(String engineType) { public static DbAdaptor getEngineAdaptor(String engineType) {

View File

@@ -19,7 +19,7 @@ public class DuckdbAdaptor extends DefaultDbAdaptor {
return metaData.getTables(schemaName, null, null, new String[] {"TABLE", "VIEW"}); return metaData.getTables(schemaName, null, null, new String[] {"TABLE", "VIEW"});
} }
public List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName) public List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException { throws SQLException {
List<DBColumn> dbColumns = Lists.newArrayList(); List<DBColumn> dbColumns = Lists.newArrayList();
DatabaseMetaData metaData = getDatabaseMetaData(connectInfo); DatabaseMetaData metaData = getDatabaseMetaData(connectInfo);

View File

@@ -46,7 +46,7 @@ public class H2Adaptor extends BaseDbAdaptor {
return metaData.getTables(schemaName, null, null, new String[] {"TABLE", "VIEW"}); return metaData.getTables(schemaName, null, null, new String[] {"TABLE", "VIEW"});
} }
public List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName) public List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException { throws SQLException {
List<DBColumn> dbColumns = Lists.newArrayList(); List<DBColumn> dbColumns = Lists.newArrayList();
DatabaseMetaData metaData = getDatabaseMetaData(connectInfo); DatabaseMetaData metaData = getDatabaseMetaData(connectInfo);

View File

@@ -0,0 +1,84 @@
package com.tencent.supersonic.headless.core.adaptor.db;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
@Slf4j
public class KyuubiAdaptor extends BaseDbAdaptor {
/** transform YYYYMMDD to YYYY-MM-DD YYYY-MM YYYY-MM-DD(MONDAY) */
@Override
public String getDateFormat(String dateType, String dateFormat, String column) {
if (dateFormat.equalsIgnoreCase(Constants.DAY_FORMAT_INT)) {
if (TimeDimensionEnum.MONTH.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(%s, 'yyyy-MM')", column);
} else if (TimeDimensionEnum.WEEK.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(date_sub(%s, (dayofweek(%s) - 2)), 'yyyy-MM-dd')",
column, column);
} else {
return String.format(
"date_format(to_date(cast(%s as string), 'yyyyMMdd'), 'yyyy-MM-dd')",
column);
}
} else if (dateFormat.equalsIgnoreCase(Constants.DAY_FORMAT)) {
if (TimeDimensionEnum.MONTH.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(%s, 'yyyy-MM')", column);
} else if (TimeDimensionEnum.WEEK.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(date_sub(%s, (dayofweek(%s) - 2)), 'yyyy-MM-dd')",
column, column);
} else {
return column;
}
}
return column;
}
@Override
public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException {
List<String> dbs = Lists.newArrayList();
final StringBuilder sql = new StringBuilder("SHOW DATABASES");
if (StringUtils.isNotBlank(catalog)) {
sql.append(" IN ").append(catalog);
}
try (Connection con = getConnection(connectionInfo);
Statement st = con.createStatement();
ResultSet rs = st.executeQuery(sql.toString())) {
while (rs.next()) {
dbs.add(rs.getString(1));
}
}
return dbs;
}
@Override
public List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName) throws SQLException {
List<String> tablesAndViews = new ArrayList<>();
try {
try (ResultSet resultSet = getDatabaseMetaData(connectInfo).getTables(catalog, schemaName, null, new String[] {"TABLE", "VIEW"})) {
while (resultSet.next()) {
String name = resultSet.getString("TABLE_NAME");
tablesAndViews.add(name);
}
}
} catch (SQLException e) {
log.error("Failed to get tables and views", e);
}
return tablesAndViews;
}
@Override
public String rewriteSql(String sql) {
return sql;
}
}

View File

@@ -99,7 +99,7 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
return tablesAndViews; return tablesAndViews;
} }
public List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName) public List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException { throws SQLException {
List<DBColumn> dbColumns = Lists.newArrayList(); List<DBColumn> dbColumns = Lists.newArrayList();
DatabaseMetaData metaData = getDatabaseMetaData(connectInfo); DatabaseMetaData metaData = getDatabaseMetaData(connectInfo);

View File

@@ -0,0 +1,88 @@
package com.tencent.supersonic.headless.core.adaptor.db;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import org.apache.commons.lang3.StringUtils;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
public class PrestoAdaptor extends BaseDbAdaptor {
/** transform YYYYMMDD to YYYY-MM-DD YYYY-MM YYYY-MM-DD(MONDAY) */
@Override
public String getDateFormat(String dateType, String dateFormat, String column) {
if (dateFormat.equalsIgnoreCase(Constants.DAY_FORMAT_INT)) {
if (TimeDimensionEnum.MONTH.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(%s, '%%Y-%%m')", column);
} else if (TimeDimensionEnum.WEEK.name().equalsIgnoreCase(dateType)) {
return String.format(
"date_format(date_add('day', - (day_of_week(%s) - 2), %s), '%%Y-%%m-%%d')",
column, column);
} else {
return String.format("date_format(date_parse(%s, '%%Y%%m%%d'), '%%Y-%%m-%%d')",
column);
}
} else if (dateFormat.equalsIgnoreCase(Constants.DAY_FORMAT)) {
if (TimeDimensionEnum.MONTH.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(%s, '%%Y-%%m')", column);
} else if (TimeDimensionEnum.WEEK.name().equalsIgnoreCase(dateType)) {
return String.format(
"date_format(date_add('day', - (day_of_week(%s) - 2), %s), '%%Y-%%m-%%d')",
column, column);
} else {
return column;
}
}
return column;
}
@Override
public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException {
List<String> dbs = Lists.newArrayList();
final StringBuilder sql = new StringBuilder("SHOW SCHEMAS");
if (StringUtils.isNotBlank(catalog)) {
sql.append(" IN ").append(catalog);
}
try (Connection con = getConnection(connectionInfo);
Statement st = con.createStatement();
ResultSet rs = st.executeQuery(sql.toString())) {
while (rs.next()) {
dbs.add(rs.getString(1));
}
}
return dbs;
}
@Override
public List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName)
throws SQLException {
List<String> tablesAndViews = new ArrayList<>();
final StringBuilder sql = new StringBuilder("SHOW TABLES");
if (StringUtils.isNotBlank(catalog)) {
sql.append(" IN ").append(catalog).append(".").append(schemaName);
}else {
sql.append(" IN ").append(schemaName);
}
try (Connection con = getConnection(connectInfo);
Statement st = con.createStatement();
ResultSet rs = st.executeQuery(sql.toString())) {
while (rs.next()) {
tablesAndViews.add(rs.getString(1));
}
}
return tablesAndViews;
}
@Override
public String rewriteSql(String sql) {
return sql;
}
}

View File

@@ -1,42 +1,87 @@
package com.tencent.supersonic.headless.core.adaptor.db; package com.tencent.supersonic.headless.core.adaptor.db;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo; import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import java.sql.*; import java.sql.*;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Properties;
@Slf4j @Slf4j
public class StarrocksAdaptor extends MysqlAdaptor { public class StarrocksAdaptor extends MysqlAdaptor {
@Override
public List<String> getCatalogs(ConnectInfo connectInfo) throws SQLException {
List<String> catalogs = Lists.newArrayList();
try (Connection con = DriverManager.getConnection(connectInfo.getUrl(),
connectInfo.getUserName(), connectInfo.getPassword());
Statement st = con.createStatement();
ResultSet rs = st.executeQuery("SHOW CATALOGS")) {
while (rs.next()) {
catalogs.add(rs.getString(1));
}
}
return catalogs;
}
@Override @Override
public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException { public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException {
Assert.hasText(catalog, "StarRocks type catalog can not be null or empty");
List<String> dbs = Lists.newArrayList(); List<String> dbs = Lists.newArrayList();
try (Connection con = DriverManager.getConnection(connectionInfo.getUrl(), final StringBuilder sql = new StringBuilder("SHOW DATABASES");
connectionInfo.getUserName(), connectionInfo.getPassword()); if (StringUtils.isNotBlank(catalog)) {
Statement st = con.createStatement(); sql.append(" IN ").append(catalog);
ResultSet rs = st.executeQuery("SHOW DATABASES IN " + catalog)) { }
try (Connection con = getConnection(connectionInfo);
Statement st = con.createStatement();
ResultSet rs = st.executeQuery(sql.toString())) {
while (rs.next()) { while (rs.next()) {
dbs.add(rs.getString(1)); dbs.add(rs.getString(1));
} }
} }
return dbs; return dbs;
} }
@Override
public List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName)
throws SQLException {
List<String> tablesAndViews = new ArrayList<>();
final StringBuilder sql = new StringBuilder("SHOW TABLES");
if (StringUtils.isNotBlank(catalog)) {
sql.append(" IN ").append(catalog).append(".").append(schemaName);
}else {
sql.append(" IN ").append(schemaName);
}
try (Connection con = getConnection(connectInfo);
Statement st = con.createStatement();
ResultSet rs = st.executeQuery(sql.toString())) {
while (rs.next()) {
tablesAndViews.add(rs.getString(1));
}
}
return tablesAndViews;
}
@Override
public List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException {
List<DBColumn> dbColumns = new ArrayList<>();
try (Connection con = getConnection(connectInfo);
Statement st = con.createStatement()) {
// 切换到指定的 catalog或 database/schema这在某些 SQL 方言中很重要
if (StringUtils.isNotBlank(catalog)) {
st.execute("SET CATALOG " + catalog);
}
// 获取 DatabaseMetaData; 需要注意调用此方法的位置(在 USE 之后)
DatabaseMetaData metaData = con.getMetaData();
// 获取特定表的列信息
try (ResultSet columns = metaData.getColumns(schemaName, schemaName, tableName, null)) {
while (columns.next()) {
String columnName = columns.getString("COLUMN_NAME");
String dataType = columns.getString("TYPE_NAME");
String remarks = columns.getString("REMARKS");
FieldType fieldType = classifyColumnType(dataType);
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
}
}
}
return dbColumns;
}
} }

View File

@@ -0,0 +1,8 @@
package com.tencent.supersonic.headless.core.adaptor.db;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
public class TrinoAdaptor extends PrestoAdaptor {
}

View File

@@ -40,11 +40,22 @@ public class JdbcDataSourceUtils {
log.error(e.toString(), e); log.error(e.toString(), e);
return false; return false;
} }
try (Connection con = DriverManager.getConnection(database.getUrl(), database.getUsername(), // presto/trino ssl=false connection need password
database.passwordDecrypt())) { if (database.getUrl().startsWith("jdbc:presto") || database.getUrl().startsWith("jdbc:trino")) {
return con != null; if (database.getUrl().toLowerCase().contains("ssl=false")) {
} catch (SQLException e) { try (Connection con = DriverManager.getConnection(database.getUrl(), database.getUsername(), null)) {
log.error(e.toString(), e); return con != null;
} catch (SQLException e) {
log.error(e.toString(), e);
}
}
}else {
try (Connection con = DriverManager.getConnection(database.getUrl(), database.getUsername(),
database.passwordDecrypt())) {
return con != null;
} catch (SQLException e) {
log.error(e.toString(), e);
}
} }
return false; return false;

View File

@@ -17,6 +17,9 @@ public class DbParameterFactory {
parametersBuilder.put(EngineType.POSTGRESQL.getName(), new PostgresqlParametersBuilder()); parametersBuilder.put(EngineType.POSTGRESQL.getName(), new PostgresqlParametersBuilder());
parametersBuilder.put(EngineType.HANADB.getName(), new HanadbParametersBuilder()); parametersBuilder.put(EngineType.HANADB.getName(), new HanadbParametersBuilder());
parametersBuilder.put(EngineType.STARROCKS.getName(), new StarrocksParametersBuilder()); parametersBuilder.put(EngineType.STARROCKS.getName(), new StarrocksParametersBuilder());
parametersBuilder.put(EngineType.KYUUBI.getName(), new KyuubiParametersBuilder());
parametersBuilder.put(EngineType.PRESTO.getName(), new PrestoParametersBuilder());
parametersBuilder.put(EngineType.TRINO.getName(), new TrinoParametersBuilder());
parametersBuilder.put(EngineType.OTHER.getName(), new OtherParametersBuilder()); parametersBuilder.put(EngineType.OTHER.getName(), new OtherParametersBuilder());
} }

View File

@@ -29,6 +29,7 @@ public class DefaultParametersBuilder implements DbParametersBuilder {
password.setComment("密码"); password.setComment("密码");
password.setName("password"); password.setName("password");
password.setPlaceholder("请输入密码"); password.setPlaceholder("请输入密码");
password.setRequire(false);
databaseParameters.add(password); databaseParameters.add(password);
return databaseParameters; return databaseParameters;

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.headless.server.pojo;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
@Slf4j
public class KyuubiParametersBuilder extends DefaultParametersBuilder {
@Override
public List<DatabaseParameter> build() {
return super.build();
}
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.headless.server.pojo;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
@Slf4j
public class PrestoParametersBuilder extends DefaultParametersBuilder {
@Override
public List<DatabaseParameter> build() {
return super.build();
}
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.headless.server.pojo;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
@Slf4j
public class TrinoParametersBuilder extends DefaultParametersBuilder {
@Override
public List<DatabaseParameter> build() {
return super.build();
}
}

View File

@@ -89,15 +89,18 @@ public class DatabaseController {
@RequestMapping("/getTables") @RequestMapping("/getTables")
public List<String> getTables(@RequestParam("databaseId") Long databaseId, public List<String> getTables(@RequestParam("databaseId") Long databaseId,
@RequestParam(value = "catalog", required = false) String catalog,
@RequestParam("db") String db) throws SQLException { @RequestParam("db") String db) throws SQLException {
return databaseService.getTables(databaseId, db); return databaseService.getTables(databaseId, catalog, db);
} }
@RequestMapping("/getColumnsByName") @RequestMapping("/getColumnsByName")
public List<DBColumn> getColumnsByName(@RequestParam("databaseId") Long databaseId, public List<DBColumn> getColumnsByName(@RequestParam("databaseId") Long databaseId,
@RequestParam("db") String db, @RequestParam("table") String table) @RequestParam(name="catalog", required = false) String catalog,
@RequestParam("db") String db,
@RequestParam("table") String table)
throws SQLException { throws SQLException {
return databaseService.getColumns(databaseId, db, table); return databaseService.getColumns(databaseId, catalog, db, table);
} }
@PostMapping("/listColumnsBySql") @PostMapping("/listColumnsBySql")

View File

@@ -40,11 +40,11 @@ public interface DatabaseService {
List<String> getDbNames(Long id, String catalog) throws SQLException; List<String> getDbNames(Long id, String catalog) throws SQLException;
List<String> getTables(Long id, String db) throws SQLException; List<String> getTables(Long id, String catalog, String db) throws SQLException;
Map<String, List<DBColumn>> getDbColumns(ModelBuildReq modelBuildReq) throws SQLException; Map<String, List<DBColumn>> getDbColumns(ModelBuildReq modelBuildReq) throws SQLException;
List<DBColumn> getColumns(Long id, String db, String table) throws SQLException; List<DBColumn> getColumns(Long id, String catalog, String db, String table) throws SQLException;
List<DBColumn> getColumns(Long id, String sql) throws SQLException; List<DBColumn> getColumns(Long id, String sql) throws SQLException;
} }

View File

@@ -214,10 +214,10 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
} }
@Override @Override
public List<String> getTables(Long id, String db) throws SQLException { public List<String> getTables(Long id, String catalog, String db) throws SQLException {
DatabaseResp databaseResp = getDatabase(id); DatabaseResp databaseResp = getDatabase(id);
DbAdaptor dbAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType()); DbAdaptor dbAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
return dbAdaptor.getTables(DatabaseConverter.getConnectInfo(databaseResp), db); return dbAdaptor.getTables(DatabaseConverter.getConnectInfo(databaseResp), catalog, db);
} }
@Override @Override
@@ -234,7 +234,7 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
} else { } else {
for (String table : modelBuildReq.getTables()) { for (String table : modelBuildReq.getTables()) {
List<DBColumn> columns = List<DBColumn> columns =
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getDb(), table); getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getCatalog(), modelBuildReq.getDb(), table);
dbColumnMap.put(table, columns); dbColumnMap.put(table, columns);
} }
} }
@@ -242,15 +242,15 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
} }
@Override @Override
public List<DBColumn> getColumns(Long id, String db, String table) throws SQLException { public List<DBColumn> getColumns(Long id, String catalog, String db, String table) throws SQLException {
DatabaseResp databaseResp = getDatabase(id); DatabaseResp databaseResp = getDatabase(id);
return getColumns(databaseResp, db, table); return getColumns(databaseResp, catalog, db, table);
} }
public List<DBColumn> getColumns(DatabaseResp databaseResp, String db, String table) public List<DBColumn> getColumns(DatabaseResp databaseResp, String catalog, String db, String table)
throws SQLException { throws SQLException {
DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType()); DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
return engineAdaptor.getColumns(DatabaseConverter.getConnectInfo(databaseResp), db, table); return engineAdaptor.getColumns(DatabaseConverter.getConnectInfo(databaseResp), catalog, db, table);
} }
@Override @Override

18
pom.xml
View File

@@ -47,6 +47,9 @@
<jjwt.version>0.12.3</jjwt.version> <jjwt.version>0.12.3</jjwt.version>
<alibaba.druid.version>1.2.24</alibaba.druid.version> <alibaba.druid.version>1.2.24</alibaba.druid.version>
<mysql.connector.java.version>5.1.46</mysql.connector.java.version> <mysql.connector.java.version>5.1.46</mysql.connector.java.version>
<kyuubi.version>1.10.1</kyuubi.version>
<presto.version>0.291</presto.version>
<trino.version>471</trino.version>
<mybatis.plus.version>3.5.7</mybatis.plus.version> <mybatis.plus.version>3.5.7</mybatis.plus.version>
<httpclient5.version>5.4.1</httpclient5.version> <httpclient5.version>5.4.1</httpclient5.version>
<!-- <httpcore.version>4.4.16</httpcore.version>--> <!-- <httpcore.version>4.4.16</httpcore.version>-->
@@ -208,6 +211,21 @@
<artifactId>mysql-connector-java</artifactId> <artifactId>mysql-connector-java</artifactId>
<version>${mysql.connector.java.version}</version> <version>${mysql.connector.java.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.apache.kyuubi</groupId>
<artifactId>kyuubi-hive-jdbc</artifactId>
<version>${kyuubi.version}</version>
</dependency>
<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-jdbc</artifactId>
<version>${presto.version}</version>
</dependency>
<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-jdbc</artifactId>
<version>${trino.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.mockito</groupId> <groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId> <artifactId>mockito-inline</artifactId>

View File

@@ -24,6 +24,7 @@ const ModelBasicForm: React.FC<Props> = ({
mode = 'normal', mode = 'normal',
}) => { }) => {
const [currentDbLinkConfigId, setCurrentDbLinkConfigId] = useState<number>(); const [currentDbLinkConfigId, setCurrentDbLinkConfigId] = useState<number>();
const [currentCatalog, setCurrentCatalog] = useState<string>("");
const [catalogList, setCatalogList] = useState<string[]>([]); const [catalogList, setCatalogList] = useState<string[]>([]);
const [dbNameList, setDbNameList] = useState<string[]>([]); const [dbNameList, setDbNameList] = useState<string[]>([]);
const [tableNameList, setTableNameList] = useState<any[]>([]); const [tableNameList, setTableNameList] = useState<any[]>([]);
@@ -55,8 +56,8 @@ const ModelBasicForm: React.FC<Props> = ({
const onDatabaseSelect = (databaseId: number, type: string) => { const onDatabaseSelect = (databaseId: number, type: string) => {
setLoading(true); setLoading(true);
if (type === 'STARROCKS') { if (type === 'STARROCKS' || type === 'KYUUBI' || type === 'PRESTO' || type === 'TRINO') {
queryCatalogList(databaseId) queryCatalogList(databaseId);
setCatalogSelectOpen(true); setCatalogSelectOpen(true);
setDbNameList([]); setDbNameList([]);
} else { } else {
@@ -88,9 +89,11 @@ const ModelBasicForm: React.FC<Props> = ({
queryDbNameList(currentDbLinkConfigId, catalog); queryDbNameList(currentDbLinkConfigId, catalog);
} }
form.setFieldsValue({ form.setFieldsValue({
catalog: catalog,
dbName: undefined, dbName: undefined,
tableName: undefined, tableName: undefined,
}) })
setCurrentCatalog(catalog);
} }
const queryDbNameList = async (databaseId: number, catalog: string) => { const queryDbNameList = async (databaseId: number, catalog: string) => {
@@ -110,7 +113,7 @@ const ModelBasicForm: React.FC<Props> = ({
return; return;
} }
setLoading(true); setLoading(true);
const { code, data, msg } = await getTables(currentDbLinkConfigId, databaseName); const { code, data, msg } = await getTables(currentDbLinkConfigId, currentCatalog, databaseName);
setLoading(false); setLoading(false);
if (code === 200) { if (code === 200) {
const list = data || []; const list = data || [];
@@ -136,6 +139,7 @@ const ModelBasicForm: React.FC<Props> = ({
onSelect={(dbLinkConfigId: number, option) => { onSelect={(dbLinkConfigId: number, option) => {
onDatabaseSelect(dbLinkConfigId, option.type); onDatabaseSelect(dbLinkConfigId, option.type);
setCurrentDbLinkConfigId(dbLinkConfigId); setCurrentDbLinkConfigId(dbLinkConfigId);
setCurrentCatalog("");
}} }}
> >
{databaseConfigList.map((item) => ( {databaseConfigList.map((item) => (

View File

@@ -351,12 +351,24 @@ const ModelCreateForm: React.FC<CreateFormProps> = ({
let columns = fieldColumns || []; let columns = fieldColumns || [];
if (queryType === 'table_query') { if (queryType === 'table_query') {
const tableQueryString = tableQuery || ''; const tableQueryString = tableQuery || '';
const [dbName, tableName] = tableQueryString.split('.'); if (tableQueryString.split('.').length === 3) {
columns = await queryTableColumnList(modelItem.databaseId, dbName, tableName); const [catalog, dbName, tableName] = tableQueryString.split('.');
tableQueryInitValue = { columns = await queryTableColumnList(modelItem.databaseId, catalog, dbName, tableName);
dbName, tableQueryInitValue = {
tableName, catalog,
}; dbName,
tableName,
};
}
if (tableQueryString.split('.').length === 2) {
const [dbName, tableName] = tableQueryString.split('.');
columns = await queryTableColumnList(modelItem.databaseId, '', dbName, tableName);
tableQueryInitValue = {
catalog: '',
dbName,
tableName,
};
}
} }
formatterInitData(columns, tableQueryInitValue); formatterInitData(columns, tableQueryInitValue);
}; };
@@ -426,8 +438,8 @@ const ModelCreateForm: React.FC<CreateFormProps> = ({
setFields(result); setFields(result);
}; };
const queryTableColumnList = async (databaseId: number, dbName: string, tableName: string) => { const queryTableColumnList = async (databaseId: number, catalog: string, dbName: string, tableName: string) => {
const { code, data, msg } = await getColumns(databaseId, dbName, tableName); const { code, data, msg } = await getColumns(databaseId, catalog, dbName, tableName);
if (code === 200) { if (code === 200) {
const list = data || []; const list = data || [];
const columns = list.map((item: any, index: number) => { const columns = list.map((item: any, index: number) => {
@@ -563,10 +575,10 @@ const ModelCreateForm: React.FC<CreateFormProps> = ({
}} }}
onValuesChange={(value, values) => { onValuesChange={(value, values) => {
const { tableName } = value; const { tableName } = value;
const { dbName, databaseId } = values; const { catalog, dbName, databaseId } = values;
setFormDatabaseId(databaseId); setFormDatabaseId(databaseId);
if (tableName) { if (tableName) {
queryTableColumnList(databaseId, dbName, tableName); queryTableColumnList(databaseId, catalog, dbName, tableName);
} }
}} }}
className={styles.form} className={styles.form}

View File

@@ -398,21 +398,23 @@ export function getDbNames(dbId: number, catalog: string): Promise<any> {
}); });
} }
export function getTables(databaseId: number, dbName: string): Promise<any> { export function getTables(databaseId: number, catalog: string, dbName: string): Promise<any> {
return request(`${process.env.API_BASE_URL}database/getTables`, { return request(`${process.env.API_BASE_URL}database/getTables`, {
method: 'GET', method: 'GET',
params: { params: {
databaseId, databaseId,
catalog: catalog,
db: dbName, db: dbName,
}, },
}); });
} }
export function getColumns(databaseId: number, dbName: string, tableName: string): Promise<any> { export function getColumns(databaseId: number, catalog: string, dbName: string, tableName: string): Promise<any> {
return request(`${process.env.API_BASE_URL}database/getColumnsByName`, { return request(`${process.env.API_BASE_URL}database/getColumnsByName`, {
method: 'GET', method: 'GET',
params: { params: {
databaseId, databaseId,
catalog: catalog,
db: dbName, db: dbName,
table: tableName, table: tableName,
}, },