feat:Support kyuubi presto trino (#2109)

This commit is contained in:
zyclove
2025-02-26 17:33:14 +08:00
committed by GitHub
parent 11ff99cdbe
commit 5e3bafb953
31 changed files with 501 additions and 101 deletions

View File

@@ -10,7 +10,10 @@ public enum EngineType {
OTHER(7, "OTHER"),
DUCKDB(8, "DUCKDB"),
HANADB(9, "HANADB"),
STARROCKS(10, "STARROCKS"),;
STARROCKS(10, "STARROCKS"),
KYUUBI(11, "KYUUBI"),
PRESTO(12, "PRESTO"),
TRINO(13, "TRINO"),;
private Integer code;

View File

@@ -11,6 +11,8 @@ import java.util.List;
@NoArgsConstructor
public class DbSchema {
private String catalog;
private String db;
private String table;

View File

@@ -8,7 +8,9 @@ import java.util.Set;
public enum DataType {
MYSQL("mysql", "mysql", "com.mysql.cj.jdbc.Driver", "`", "`", "'", "'"),
HIVE2("hive2", "hive", "org.apache.hive.jdbc.HiveDriver", "`", "`", "`", "`"),
HIVE2("hive2", "hive", "org.apache.kyuubi.jdbc.KyuubiHiveDriver", "`", "`", "`", "`"),
KYUUBI("kyuubi", "kyuubi", "org.apache.kyuubi.jdbc.KyuubiHiveDriver", "`", "`", "`", "`"),
ORACLE("oracle", "oracle", "oracle.jdbc.driver.OracleDriver", "\"", "\"", "\"", "\""),
@@ -27,6 +29,8 @@ public enum DataType {
PRESTO("presto", "presto", "com.facebook.presto.jdbc.PrestoDriver", "\"", "\"", "\"", "\""),
TRINO("trino", "trino", "io.trino.jdbc.TrinoDriver", "\"", "\"", "\"", "\""),
MOONBOX("moonbox", "moonbox", "moonbox.jdbc.MbDriver", "`", "`", "`", "`"),
CASSANDRA("cassandra", "cassandra", "com.github.adejanovski.cassandra.jdbc.CassandraDriver", "",
@@ -46,6 +50,7 @@ public enum DataType {
TDENGINE("TAOS", "TAOS", "com.taosdata.jdbc.TSDBDriver", "'", "'", "\"", "\""),
POSTGRESQL("postgresql", "postgresql", "org.postgresql.Driver", "'", "'", "\"", "\""),
DUCKDB("duckdb", "duckdb", "org.duckdb.DuckDBDriver", "'", "'", "\"", "\"");
private String feature;

View File

@@ -19,6 +19,8 @@ public class ModelBuildReq {
private String sql;
private String catalog;
private String db;
private List<String> tables;

View File

@@ -72,7 +72,8 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
// 1. Base detection
List<EmbeddingResult> baseResults = super.detect(chatQueryContext, terms, detectDataSetIds);
boolean useLLM = Boolean.parseBoolean(mapperConfig.getParameterValue(EMBEDDING_MAPPER_USE_LLM));
boolean useLLM =
Boolean.parseBoolean(mapperConfig.getParameterValue(EMBEDDING_MAPPER_USE_LLM));
// 2. LLM enhanced detection
if (useLLM) {
@@ -115,7 +116,8 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
* Extract valid word segments by filtering out unwanted word natures
*/
private Set<String> extractValidSegments(String text) {
List<String> natureList = Arrays.asList(StringUtils.split(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE ), ","));
List<String> natureList = Arrays.asList(StringUtils.split(
mapperConfig.getParameterValue(EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE), ","));
return HanlpHelper.getSegment().seg(text).stream()
.filter(t -> natureList.stream().noneMatch(nature -> t.nature.startsWith(nature)))
.map(Term::getWord).collect(Collectors.toSet());

View File

@@ -61,7 +61,8 @@ public class MapFilter {
List<SchemaElementMatch> value = entry.getValue();
if (!CollectionUtils.isEmpty(value)) {
value.removeIf(schemaElementMatch -> StringUtils
.length(schemaElementMatch.getDetectWord()) <= 1 && !schemaElementMatch.isLlmMatched());
.length(schemaElementMatch.getDetectWord()) <= 1
&& !schemaElementMatch.isLlmMatched());
}
}
}

View File

@@ -63,6 +63,6 @@ public class MapperConfig extends ParameterConfig {
"embedding的结果再通过一次LLM来筛选这时候忽略各个向量阀值", "bool", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE =
new Parameter("s2.mapper.embedding.allowed-segment-nature", "['v', 'd', 'a']", "使用LLM召回二次处理时对问题分词词性的控制",
"分词后允许的词性才会进行向量召回", "list", "Mapper相关配置");
new Parameter("s2.mapper.embedding.allowed-segment-nature", "['v', 'd', 'a']",
"使用LLM召回二次处理时对问题分词词性的控制", "分词后允许的词性才会进行向量召回", "list", "Mapper相关配置");
}

View File

@@ -121,6 +121,18 @@
<artifactId>DmJdbcDriver18</artifactId>
<version>8.1.2.192</version>
</dependency>
<dependency>
<groupId>org.apache.kyuubi</groupId>
<artifactId>kyuubi-hive-jdbc</artifactId>
</dependency>
<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-jdbc</artifactId>
</dependency>
<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-jdbc</artifactId>
</dependency>
</dependencies>

View File

@@ -5,23 +5,27 @@ import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
@Slf4j
public abstract class BaseDbAdaptor implements DbAdaptor {
@Override
public List<String> getCatalogs(ConnectInfo connectInfo) throws SQLException {
// Apart from supporting multiple catalog types of data sources, other types will return an
// empty set by default.
return List.of();
List<String> catalogs = Lists.newArrayList();
try (Connection con = getConnection(connectInfo);
Statement st = con.createStatement();
ResultSet rs = st.executeQuery("SHOW CATALOGS")) {
while (rs.next()) {
catalogs.add(rs.getString(1));
}
}
return catalogs;
}
public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException {
@@ -32,39 +36,50 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
protected List<String> getDBs(ConnectInfo connectionInfo) throws SQLException {
List<String> dbs = Lists.newArrayList();
DatabaseMetaData metaData = getDatabaseMetaData(connectionInfo);
try {
ResultSet schemaSet = metaData.getSchemas();
try (ResultSet schemaSet = getDatabaseMetaData(connectionInfo).getSchemas()) {
while (schemaSet.next()) {
String db = schemaSet.getString("TABLE_SCHEM");
dbs.add(db);
}
}
} catch (Exception e) {
log.info("get meta schemas failed, try to get catalogs");
log.warn("get meta schemas failed", e);
log.warn("get meta schemas failed, try to get catalogs");
}
try {
ResultSet catalogSet = metaData.getCatalogs();
try (ResultSet catalogSet = getDatabaseMetaData(connectionInfo).getCatalogs()) {
while (catalogSet.next()) {
String db = catalogSet.getString("TABLE_CAT");
dbs.add(db);
}
}
} catch (Exception e) {
log.info("get meta catalogs failed, try to get schemas");
log.warn("get meta catalogs failed", e);
log.warn("get meta catalogs failed, try to get schemas");
}
return dbs;
}
public List<String> getTables(ConnectInfo connectionInfo, String schemaName)
@Override
public List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName)
throws SQLException {
// Except for special types implemented separately, the generic logic catalog does not take
// effect.
return getTables(connectInfo, schemaName);
}
protected List<String> getTables(ConnectInfo connectionInfo, String schemaName)
throws SQLException {
List<String> tablesAndViews = new ArrayList<>();
DatabaseMetaData metaData = getDatabaseMetaData(connectionInfo);
try {
ResultSet resultSet = getResultSet(schemaName, metaData);
try(ResultSet resultSet = getResultSet(schemaName, getDatabaseMetaData(connectionInfo))) {
while (resultSet.next()) {
String name = resultSet.getString("TABLE_NAME");
tablesAndViews.add(name);
}
}
} catch (SQLException e) {
log.error("Failed to get tables and views", e);
}
@@ -76,11 +91,13 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
return metaData.getTables(schemaName, schemaName, null, new String[] {"TABLE", "VIEW"});
}
public List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
public List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException {
List<DBColumn> dbColumns = Lists.newArrayList();
DatabaseMetaData metaData = getDatabaseMetaData(connectInfo);
ResultSet columns = metaData.getColumns(schemaName, schemaName, tableName, null);
List<DBColumn> dbColumns = new ArrayList<>();
// 确保连接会自动关闭
try (ResultSet columns = getDatabaseMetaData(connectInfo).getColumns(catalog, schemaName, tableName, null)) {
while (columns.next()) {
String columnName = columns.getString("COLUMN_NAME");
String dataType = columns.getString("TYPE_NAME");
@@ -88,15 +105,20 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
FieldType fieldType = classifyColumnType(dataType);
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
}
}
return dbColumns;
}
protected DatabaseMetaData getDatabaseMetaData(ConnectInfo connectionInfo) throws SQLException {
Connection connection = DriverManager.getConnection(connectionInfo.getUrl(),
connectionInfo.getUserName(), connectionInfo.getPassword());
Connection connection = getConnection(connectionInfo);
return connection.getMetaData();
}
public Connection getConnection(ConnectInfo connectionInfo) throws SQLException {
final Properties properties = getProperties(connectionInfo);
return DriverManager.getConnection(connectionInfo.getUrl(), properties);
}
public FieldType classifyColumnType(String typeName) {
switch (typeName.toUpperCase()) {
case "INT":
@@ -118,4 +140,24 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
}
}
public Properties getProperties(ConnectInfo connectionInfo) {
final Properties properties = new Properties();
String url = connectionInfo.getUrl().toLowerCase();
// 设置通用属性
properties.setProperty("user", connectionInfo.getUserName());
// 针对 Presto 和 Trino ssl=false 的情况,不需要设置密码
if (url.startsWith("jdbc:presto") || url.startsWith("jdbc:trino")) {
// 检查是否需要处理 SSL
if (!url.contains("ssl=false")) {
properties.setProperty("password", connectionInfo.getPassword());
}
} else {
// 针对其他数据库类型
properties.setProperty("password", connectionInfo.getPassword());
}
return properties;
}
}

View File

@@ -18,9 +18,10 @@ public interface DbAdaptor {
List<String> getDBs(ConnectInfo connectInfo, String catalog) throws SQLException;
List<String> getTables(ConnectInfo connectInfo, String schemaName) throws SQLException;
List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName)
throws SQLException;
List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException;
FieldType classifyColumnType(String typeName);

View File

@@ -19,6 +19,9 @@ public class DbAdaptorFactory {
dbAdaptorMap.put(EngineType.DUCKDB.getName(), new DuckdbAdaptor());
dbAdaptorMap.put(EngineType.HANADB.getName(), new HanadbAdaptor());
dbAdaptorMap.put(EngineType.STARROCKS.getName(), new StarrocksAdaptor());
dbAdaptorMap.put(EngineType.KYUUBI.getName(), new KyuubiAdaptor());
dbAdaptorMap.put(EngineType.PRESTO.getName(), new PrestoAdaptor());
dbAdaptorMap.put(EngineType.TRINO.getName(), new TrinoAdaptor());
}
public static DbAdaptor getEngineAdaptor(String engineType) {

View File

@@ -19,7 +19,7 @@ public class DuckdbAdaptor extends DefaultDbAdaptor {
return metaData.getTables(schemaName, null, null, new String[] {"TABLE", "VIEW"});
}
public List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
public List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException {
List<DBColumn> dbColumns = Lists.newArrayList();
DatabaseMetaData metaData = getDatabaseMetaData(connectInfo);

View File

@@ -46,7 +46,7 @@ public class H2Adaptor extends BaseDbAdaptor {
return metaData.getTables(schemaName, null, null, new String[] {"TABLE", "VIEW"});
}
public List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
public List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException {
List<DBColumn> dbColumns = Lists.newArrayList();
DatabaseMetaData metaData = getDatabaseMetaData(connectInfo);

View File

@@ -0,0 +1,84 @@
package com.tencent.supersonic.headless.core.adaptor.db;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
@Slf4j
public class KyuubiAdaptor extends BaseDbAdaptor {
/** transform YYYYMMDD to YYYY-MM-DD YYYY-MM YYYY-MM-DD(MONDAY) */
@Override
public String getDateFormat(String dateType, String dateFormat, String column) {
if (dateFormat.equalsIgnoreCase(Constants.DAY_FORMAT_INT)) {
if (TimeDimensionEnum.MONTH.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(%s, 'yyyy-MM')", column);
} else if (TimeDimensionEnum.WEEK.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(date_sub(%s, (dayofweek(%s) - 2)), 'yyyy-MM-dd')",
column, column);
} else {
return String.format(
"date_format(to_date(cast(%s as string), 'yyyyMMdd'), 'yyyy-MM-dd')",
column);
}
} else if (dateFormat.equalsIgnoreCase(Constants.DAY_FORMAT)) {
if (TimeDimensionEnum.MONTH.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(%s, 'yyyy-MM')", column);
} else if (TimeDimensionEnum.WEEK.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(date_sub(%s, (dayofweek(%s) - 2)), 'yyyy-MM-dd')",
column, column);
} else {
return column;
}
}
return column;
}
@Override
public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException {
List<String> dbs = Lists.newArrayList();
final StringBuilder sql = new StringBuilder("SHOW DATABASES");
if (StringUtils.isNotBlank(catalog)) {
sql.append(" IN ").append(catalog);
}
try (Connection con = getConnection(connectionInfo);
Statement st = con.createStatement();
ResultSet rs = st.executeQuery(sql.toString())) {
while (rs.next()) {
dbs.add(rs.getString(1));
}
}
return dbs;
}
@Override
public List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName) throws SQLException {
List<String> tablesAndViews = new ArrayList<>();
try {
try (ResultSet resultSet = getDatabaseMetaData(connectInfo).getTables(catalog, schemaName, null, new String[] {"TABLE", "VIEW"})) {
while (resultSet.next()) {
String name = resultSet.getString("TABLE_NAME");
tablesAndViews.add(name);
}
}
} catch (SQLException e) {
log.error("Failed to get tables and views", e);
}
return tablesAndViews;
}
@Override
public String rewriteSql(String sql) {
return sql;
}
}

View File

@@ -99,7 +99,7 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
return tablesAndViews;
}
public List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
public List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException {
List<DBColumn> dbColumns = Lists.newArrayList();
DatabaseMetaData metaData = getDatabaseMetaData(connectInfo);

View File

@@ -0,0 +1,88 @@
package com.tencent.supersonic.headless.core.adaptor.db;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import org.apache.commons.lang3.StringUtils;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
public class PrestoAdaptor extends BaseDbAdaptor {
/** transform YYYYMMDD to YYYY-MM-DD YYYY-MM YYYY-MM-DD(MONDAY) */
@Override
public String getDateFormat(String dateType, String dateFormat, String column) {
if (dateFormat.equalsIgnoreCase(Constants.DAY_FORMAT_INT)) {
if (TimeDimensionEnum.MONTH.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(%s, '%%Y-%%m')", column);
} else if (TimeDimensionEnum.WEEK.name().equalsIgnoreCase(dateType)) {
return String.format(
"date_format(date_add('day', - (day_of_week(%s) - 2), %s), '%%Y-%%m-%%d')",
column, column);
} else {
return String.format("date_format(date_parse(%s, '%%Y%%m%%d'), '%%Y-%%m-%%d')",
column);
}
} else if (dateFormat.equalsIgnoreCase(Constants.DAY_FORMAT)) {
if (TimeDimensionEnum.MONTH.name().equalsIgnoreCase(dateType)) {
return String.format("date_format(%s, '%%Y-%%m')", column);
} else if (TimeDimensionEnum.WEEK.name().equalsIgnoreCase(dateType)) {
return String.format(
"date_format(date_add('day', - (day_of_week(%s) - 2), %s), '%%Y-%%m-%%d')",
column, column);
} else {
return column;
}
}
return column;
}
@Override
public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException {
List<String> dbs = Lists.newArrayList();
final StringBuilder sql = new StringBuilder("SHOW SCHEMAS");
if (StringUtils.isNotBlank(catalog)) {
sql.append(" IN ").append(catalog);
}
try (Connection con = getConnection(connectionInfo);
Statement st = con.createStatement();
ResultSet rs = st.executeQuery(sql.toString())) {
while (rs.next()) {
dbs.add(rs.getString(1));
}
}
return dbs;
}
@Override
public List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName)
throws SQLException {
List<String> tablesAndViews = new ArrayList<>();
final StringBuilder sql = new StringBuilder("SHOW TABLES");
if (StringUtils.isNotBlank(catalog)) {
sql.append(" IN ").append(catalog).append(".").append(schemaName);
}else {
sql.append(" IN ").append(schemaName);
}
try (Connection con = getConnection(connectInfo);
Statement st = con.createStatement();
ResultSet rs = st.executeQuery(sql.toString())) {
while (rs.next()) {
tablesAndViews.add(rs.getString(1));
}
}
return tablesAndViews;
}
@Override
public String rewriteSql(String sql) {
return sql;
}
}

View File

@@ -1,42 +1,87 @@
package com.tencent.supersonic.headless.core.adaptor.db;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.Assert;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
@Slf4j
public class StarrocksAdaptor extends MysqlAdaptor {
@Override
public List<String> getCatalogs(ConnectInfo connectInfo) throws SQLException {
List<String> catalogs = Lists.newArrayList();
try (Connection con = DriverManager.getConnection(connectInfo.getUrl(),
connectInfo.getUserName(), connectInfo.getPassword());
Statement st = con.createStatement();
ResultSet rs = st.executeQuery("SHOW CATALOGS")) {
while (rs.next()) {
catalogs.add(rs.getString(1));
}
}
return catalogs;
}
@Override
public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException {
Assert.hasText(catalog, "StarRocks type catalog can not be null or empty");
List<String> dbs = Lists.newArrayList();
try (Connection con = DriverManager.getConnection(connectionInfo.getUrl(),
connectionInfo.getUserName(), connectionInfo.getPassword());
final StringBuilder sql = new StringBuilder("SHOW DATABASES");
if (StringUtils.isNotBlank(catalog)) {
sql.append(" IN ").append(catalog);
}
try (Connection con = getConnection(connectionInfo);
Statement st = con.createStatement();
ResultSet rs = st.executeQuery("SHOW DATABASES IN " + catalog)) {
ResultSet rs = st.executeQuery(sql.toString())) {
while (rs.next()) {
dbs.add(rs.getString(1));
}
}
return dbs;
}
@Override
public List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName)
throws SQLException {
List<String> tablesAndViews = new ArrayList<>();
final StringBuilder sql = new StringBuilder("SHOW TABLES");
if (StringUtils.isNotBlank(catalog)) {
sql.append(" IN ").append(catalog).append(".").append(schemaName);
}else {
sql.append(" IN ").append(schemaName);
}
try (Connection con = getConnection(connectInfo);
Statement st = con.createStatement();
ResultSet rs = st.executeQuery(sql.toString())) {
while (rs.next()) {
tablesAndViews.add(rs.getString(1));
}
}
return tablesAndViews;
}
@Override
public List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
throws SQLException {
List<DBColumn> dbColumns = new ArrayList<>();
try (Connection con = getConnection(connectInfo);
Statement st = con.createStatement()) {
// 切换到指定的 catalog或 database/schema这在某些 SQL 方言中很重要
if (StringUtils.isNotBlank(catalog)) {
st.execute("SET CATALOG " + catalog);
}
// 获取 DatabaseMetaData; 需要注意调用此方法的位置(在 USE 之后)
DatabaseMetaData metaData = con.getMetaData();
// 获取特定表的列信息
try (ResultSet columns = metaData.getColumns(schemaName, schemaName, tableName, null)) {
while (columns.next()) {
String columnName = columns.getString("COLUMN_NAME");
String dataType = columns.getString("TYPE_NAME");
String remarks = columns.getString("REMARKS");
FieldType fieldType = classifyColumnType(dataType);
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
}
}
}
return dbColumns;
}
}

View File

@@ -0,0 +1,8 @@
package com.tencent.supersonic.headless.core.adaptor.db;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
public class TrinoAdaptor extends PrestoAdaptor {
}

View File

@@ -40,12 +40,23 @@ public class JdbcDataSourceUtils {
log.error(e.toString(), e);
return false;
}
// presto/trino ssl=false connection need password
if (database.getUrl().startsWith("jdbc:presto") || database.getUrl().startsWith("jdbc:trino")) {
if (database.getUrl().toLowerCase().contains("ssl=false")) {
try (Connection con = DriverManager.getConnection(database.getUrl(), database.getUsername(), null)) {
return con != null;
} catch (SQLException e) {
log.error(e.toString(), e);
}
}
}else {
try (Connection con = DriverManager.getConnection(database.getUrl(), database.getUsername(),
database.passwordDecrypt())) {
return con != null;
} catch (SQLException e) {
log.error(e.toString(), e);
}
}
return false;
}

View File

@@ -17,6 +17,9 @@ public class DbParameterFactory {
parametersBuilder.put(EngineType.POSTGRESQL.getName(), new PostgresqlParametersBuilder());
parametersBuilder.put(EngineType.HANADB.getName(), new HanadbParametersBuilder());
parametersBuilder.put(EngineType.STARROCKS.getName(), new StarrocksParametersBuilder());
parametersBuilder.put(EngineType.KYUUBI.getName(), new KyuubiParametersBuilder());
parametersBuilder.put(EngineType.PRESTO.getName(), new PrestoParametersBuilder());
parametersBuilder.put(EngineType.TRINO.getName(), new TrinoParametersBuilder());
parametersBuilder.put(EngineType.OTHER.getName(), new OtherParametersBuilder());
}

View File

@@ -29,6 +29,7 @@ public class DefaultParametersBuilder implements DbParametersBuilder {
password.setComment("密码");
password.setName("password");
password.setPlaceholder("请输入密码");
password.setRequire(false);
databaseParameters.add(password);
return databaseParameters;

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.headless.server.pojo;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
@Slf4j
public class KyuubiParametersBuilder extends DefaultParametersBuilder {
@Override
public List<DatabaseParameter> build() {
return super.build();
}
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.headless.server.pojo;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
@Slf4j
public class PrestoParametersBuilder extends DefaultParametersBuilder {
@Override
public List<DatabaseParameter> build() {
return super.build();
}
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.headless.server.pojo;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
@Slf4j
public class TrinoParametersBuilder extends DefaultParametersBuilder {
@Override
public List<DatabaseParameter> build() {
return super.build();
}
}

View File

@@ -89,15 +89,18 @@ public class DatabaseController {
@RequestMapping("/getTables")
public List<String> getTables(@RequestParam("databaseId") Long databaseId,
@RequestParam(value = "catalog", required = false) String catalog,
@RequestParam("db") String db) throws SQLException {
return databaseService.getTables(databaseId, db);
return databaseService.getTables(databaseId, catalog, db);
}
@RequestMapping("/getColumnsByName")
public List<DBColumn> getColumnsByName(@RequestParam("databaseId") Long databaseId,
@RequestParam("db") String db, @RequestParam("table") String table)
@RequestParam(name="catalog", required = false) String catalog,
@RequestParam("db") String db,
@RequestParam("table") String table)
throws SQLException {
return databaseService.getColumns(databaseId, db, table);
return databaseService.getColumns(databaseId, catalog, db, table);
}
@PostMapping("/listColumnsBySql")

View File

@@ -40,11 +40,11 @@ public interface DatabaseService {
List<String> getDbNames(Long id, String catalog) throws SQLException;
List<String> getTables(Long id, String db) throws SQLException;
List<String> getTables(Long id, String catalog, String db) throws SQLException;
Map<String, List<DBColumn>> getDbColumns(ModelBuildReq modelBuildReq) throws SQLException;
List<DBColumn> getColumns(Long id, String db, String table) throws SQLException;
List<DBColumn> getColumns(Long id, String catalog, String db, String table) throws SQLException;
List<DBColumn> getColumns(Long id, String sql) throws SQLException;
}

View File

@@ -214,10 +214,10 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
}
@Override
public List<String> getTables(Long id, String db) throws SQLException {
public List<String> getTables(Long id, String catalog, String db) throws SQLException {
DatabaseResp databaseResp = getDatabase(id);
DbAdaptor dbAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
return dbAdaptor.getTables(DatabaseConverter.getConnectInfo(databaseResp), db);
return dbAdaptor.getTables(DatabaseConverter.getConnectInfo(databaseResp), catalog, db);
}
@Override
@@ -234,7 +234,7 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
} else {
for (String table : modelBuildReq.getTables()) {
List<DBColumn> columns =
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getDb(), table);
getColumns(modelBuildReq.getDatabaseId(), modelBuildReq.getCatalog(), modelBuildReq.getDb(), table);
dbColumnMap.put(table, columns);
}
}
@@ -242,15 +242,15 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
}
@Override
public List<DBColumn> getColumns(Long id, String db, String table) throws SQLException {
public List<DBColumn> getColumns(Long id, String catalog, String db, String table) throws SQLException {
DatabaseResp databaseResp = getDatabase(id);
return getColumns(databaseResp, db, table);
return getColumns(databaseResp, catalog, db, table);
}
public List<DBColumn> getColumns(DatabaseResp databaseResp, String db, String table)
public List<DBColumn> getColumns(DatabaseResp databaseResp, String catalog, String db, String table)
throws SQLException {
DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType());
return engineAdaptor.getColumns(DatabaseConverter.getConnectInfo(databaseResp), db, table);
return engineAdaptor.getColumns(DatabaseConverter.getConnectInfo(databaseResp), catalog, db, table);
}
@Override

18
pom.xml
View File

@@ -47,6 +47,9 @@
<jjwt.version>0.12.3</jjwt.version>
<alibaba.druid.version>1.2.24</alibaba.druid.version>
<mysql.connector.java.version>5.1.46</mysql.connector.java.version>
<kyuubi.version>1.10.1</kyuubi.version>
<presto.version>0.291</presto.version>
<trino.version>471</trino.version>
<mybatis.plus.version>3.5.7</mybatis.plus.version>
<httpclient5.version>5.4.1</httpclient5.version>
<!-- <httpcore.version>4.4.16</httpcore.version>-->
@@ -208,6 +211,21 @@
<artifactId>mysql-connector-java</artifactId>
<version>${mysql.connector.java.version}</version>
</dependency>
<dependency>
<groupId>org.apache.kyuubi</groupId>
<artifactId>kyuubi-hive-jdbc</artifactId>
<version>${kyuubi.version}</version>
</dependency>
<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-jdbc</artifactId>
<version>${presto.version}</version>
</dependency>
<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-jdbc</artifactId>
<version>${trino.version}</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>

View File

@@ -24,6 +24,7 @@ const ModelBasicForm: React.FC<Props> = ({
mode = 'normal',
}) => {
const [currentDbLinkConfigId, setCurrentDbLinkConfigId] = useState<number>();
const [currentCatalog, setCurrentCatalog] = useState<string>("");
const [catalogList, setCatalogList] = useState<string[]>([]);
const [dbNameList, setDbNameList] = useState<string[]>([]);
const [tableNameList, setTableNameList] = useState<any[]>([]);
@@ -55,8 +56,8 @@ const ModelBasicForm: React.FC<Props> = ({
const onDatabaseSelect = (databaseId: number, type: string) => {
setLoading(true);
if (type === 'STARROCKS') {
queryCatalogList(databaseId)
if (type === 'STARROCKS' || type === 'KYUUBI' || type === 'PRESTO' || type === 'TRINO') {
queryCatalogList(databaseId);
setCatalogSelectOpen(true);
setDbNameList([]);
} else {
@@ -88,9 +89,11 @@ const ModelBasicForm: React.FC<Props> = ({
queryDbNameList(currentDbLinkConfigId, catalog);
}
form.setFieldsValue({
catalog: catalog,
dbName: undefined,
tableName: undefined,
})
setCurrentCatalog(catalog);
}
const queryDbNameList = async (databaseId: number, catalog: string) => {
@@ -110,7 +113,7 @@ const ModelBasicForm: React.FC<Props> = ({
return;
}
setLoading(true);
const { code, data, msg } = await getTables(currentDbLinkConfigId, databaseName);
const { code, data, msg } = await getTables(currentDbLinkConfigId, currentCatalog, databaseName);
setLoading(false);
if (code === 200) {
const list = data || [];
@@ -136,6 +139,7 @@ const ModelBasicForm: React.FC<Props> = ({
onSelect={(dbLinkConfigId: number, option) => {
onDatabaseSelect(dbLinkConfigId, option.type);
setCurrentDbLinkConfigId(dbLinkConfigId);
setCurrentCatalog("");
}}
>
{databaseConfigList.map((item) => (

View File

@@ -351,13 +351,25 @@ const ModelCreateForm: React.FC<CreateFormProps> = ({
let columns = fieldColumns || [];
if (queryType === 'table_query') {
const tableQueryString = tableQuery || '';
const [dbName, tableName] = tableQueryString.split('.');
columns = await queryTableColumnList(modelItem.databaseId, dbName, tableName);
if (tableQueryString.split('.').length === 3) {
const [catalog, dbName, tableName] = tableQueryString.split('.');
columns = await queryTableColumnList(modelItem.databaseId, catalog, dbName, tableName);
tableQueryInitValue = {
catalog,
dbName,
tableName,
};
}
if (tableQueryString.split('.').length === 2) {
const [dbName, tableName] = tableQueryString.split('.');
columns = await queryTableColumnList(modelItem.databaseId, '', dbName, tableName);
tableQueryInitValue = {
catalog: '',
dbName,
tableName,
};
}
}
formatterInitData(columns, tableQueryInitValue);
};
@@ -426,8 +438,8 @@ const ModelCreateForm: React.FC<CreateFormProps> = ({
setFields(result);
};
const queryTableColumnList = async (databaseId: number, dbName: string, tableName: string) => {
const { code, data, msg } = await getColumns(databaseId, dbName, tableName);
const queryTableColumnList = async (databaseId: number, catalog: string, dbName: string, tableName: string) => {
const { code, data, msg } = await getColumns(databaseId, catalog, dbName, tableName);
if (code === 200) {
const list = data || [];
const columns = list.map((item: any, index: number) => {
@@ -563,10 +575,10 @@ const ModelCreateForm: React.FC<CreateFormProps> = ({
}}
onValuesChange={(value, values) => {
const { tableName } = value;
const { dbName, databaseId } = values;
const { catalog, dbName, databaseId } = values;
setFormDatabaseId(databaseId);
if (tableName) {
queryTableColumnList(databaseId, dbName, tableName);
queryTableColumnList(databaseId, catalog, dbName, tableName);
}
}}
className={styles.form}

View File

@@ -398,21 +398,23 @@ export function getDbNames(dbId: number, catalog: string): Promise<any> {
});
}
export function getTables(databaseId: number, dbName: string): Promise<any> {
export function getTables(databaseId: number, catalog: string, dbName: string): Promise<any> {
return request(`${process.env.API_BASE_URL}database/getTables`, {
method: 'GET',
params: {
databaseId,
catalog: catalog,
db: dbName,
},
});
}
export function getColumns(databaseId: number, dbName: string, tableName: string): Promise<any> {
export function getColumns(databaseId: number, catalog: string, dbName: string, tableName: string): Promise<any> {
return request(`${process.env.API_BASE_URL}database/getColumnsByName`, {
method: 'GET',
params: {
databaseId,
catalog: catalog,
db: dbName,
table: tableName,
},