(feature)(Headless) arrow flight sql endpoint (#634) (#1136)

This commit is contained in:
jipeli
2024-06-12 15:02:55 +08:00
committed by GitHub
parent 91a6308d9e
commit abf440c4d4
19 changed files with 697 additions and 7 deletions

View File

@@ -23,6 +23,8 @@ public interface UserAdaptor {
String login(UserReq userReq, HttpServletRequest request);
String login(UserReq userReq, String appKey);
List<User> getUserByOrg(String key);
Set<String> getUserAllOrgId(String userName);

View File

@@ -21,6 +21,8 @@ public interface UserService {
String login(UserReq userCmd, HttpServletRequest request);
String login(UserReq userCmd, String appKey);
Set<String> getUserAllOrgId(String userName);
List<User> getUserByOrg(String key);

View File

@@ -11,4 +11,6 @@ public interface UserStrategy {
User findUser(HttpServletRequest request, HttpServletResponse response);
User findUser(String token, String appKey);
}

View File

@@ -5,10 +5,9 @@ import com.tencent.supersonic.auth.api.authentication.service.UserStrategy;
import com.tencent.supersonic.common.pojo.SystemConfig;
import com.tencent.supersonic.common.service.SystemConfigService;
import com.tencent.supersonic.common.util.ContextUtils;
import org.springframework.util.CollectionUtils;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.util.CollectionUtils;
public final class UserHolder {
@@ -20,6 +19,15 @@ public final class UserHolder {
public static User findUser(HttpServletRequest request, HttpServletResponse response) {
User user = REPO.findUser(request, response);
return getUser(user);
}
public static User findUser(String token, String appKey) {
User user = REPO.findUser(token, appKey);
return getUser(user);
}
private static User getUser(User user) {
SystemConfigService sysParameterService = ContextUtils.getBean(SystemConfigService.class);
SystemConfig systemConfig = sysParameterService.getSystemConfig();
if (!CollectionUtils.isEmpty(systemConfig.getAdmins())

View File

@@ -90,25 +90,43 @@ public class DefaultUserAdaptor implements UserAdaptor {
@Override
public String login(UserReq userReq, HttpServletRequest request) {
UserTokenUtils userTokenUtils = ContextUtils.getBean(UserTokenUtils.class);
try {
UserWithPassword user = getUserWithPassword(userReq);
return userTokenUtils.generateToken(user, request);
} catch (Exception e) {
throw new RuntimeException("password encrypt error, please try again");
}
}
@Override
public String login(UserReq userReq, String appKey) {
UserTokenUtils userTokenUtils = ContextUtils.getBean(UserTokenUtils.class);
try {
UserWithPassword user = getUserWithPassword(userReq);
return userTokenUtils.generateToken(user, appKey);
} catch (Exception e) {
throw new RuntimeException("password encrypt error, please try again");
}
}
private UserWithPassword getUserWithPassword(UserReq userReq) {
UserDO userDO = getUser(userReq.getName());
if (userDO == null) {
throw new RuntimeException("user not exist,please register");
}
try {
String password = AESEncryptionUtil.encrypt(userReq.getPassword(),
AESEncryptionUtil.getBytesFromString(userDO.getSalt()));
if (userDO.getPassword().equals(password)) {
UserWithPassword user = UserWithPassword.get(userDO.getId(), userDO.getName(), userDO.getDisplayName(),
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
return userTokenUtils.generateToken(user, request);
return user;
} else {
throw new RuntimeException("password not correct, please try again");
}
} catch (Exception e) {
throw new RuntimeException("password encrypt error, please try again");
}
}
@Override

View File

@@ -72,4 +72,9 @@ public class UserServiceImpl implements UserService {
return ComponentFactory.getUserAdaptor().login(userReq, request);
}
@Override
public String login(UserReq userReq, String appKey) {
return ComponentFactory.getUserAdaptor().login(userReq, appKey);
}
}

View File

@@ -20,4 +20,9 @@ public class FakeUserStrategy implements UserStrategy {
return User.getFakeUser();
}
@Override
public User findUser(String token, String appKey) {
return User.getFakeUser();
}
}

View File

@@ -28,4 +28,9 @@ public class HttpHeaderUserStrategy implements UserStrategy {
public User findUser(HttpServletRequest request, HttpServletResponse response) {
return userTokenUtils.getUser(request);
}
@Override
public User findUser(String token, String appKey) {
return userTokenUtils.getUser(token, appKey);
}
}

View File

@@ -37,6 +37,11 @@ public class UserTokenUtils {
}
public String generateToken(UserWithPassword user, HttpServletRequest request) {
String appKey = getAppKey(request);
return generateToken(user, appKey);
}
public String generateToken(UserWithPassword user, String appKey) {
Map<String, Object> claims = new HashMap<>(5);
claims.put(TOKEN_USER_ID, user.getId());
claims.put(TOKEN_USER_NAME, StringUtils.isEmpty(user.getName()) ? "" : user.getName());
@@ -44,7 +49,6 @@ public class UserTokenUtils {
claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName());
claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis());
claims.put(TOKEN_IS_ADMIN, user.getIsAdmin());
String appKey = getAppKey(request);
return generate(claims, appKey);
}
@@ -61,6 +65,15 @@ public class UserTokenUtils {
public User getUser(HttpServletRequest request) {
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
final Claims claims = getClaims(token, request);
return getUser(claims);
}
public User getUser(String token, String appKey) {
final Claims claims = getClaims(token, appKey);
return getUser(claims);
}
private User getUser(Claims claims) {
Long userId = Long.parseLong(claims.getOrDefault(TOKEN_USER_ID, 0).toString());
String userName = String.valueOf(claims.get(TOKEN_USER_NAME));
String email = String.valueOf(claims.get(TOKEN_USER_EMAIL));
@@ -92,6 +105,16 @@ public class UserTokenUtils {
Claims claims;
try {
String appKey = getAppKey(request);
claims = getClaims(token, appKey);
} catch (Exception e) {
throw new AccessException("parse user info from token failed :" + token);
}
return claims;
}
private Claims getClaims(String token, String appKey) {
Claims claims;
try {
String tokenSecret = getTokenSecret(appKey);
claims = Jwts.parser()
.setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8))

View File

@@ -89,6 +89,12 @@
<groupId>org.apache.calcite.avatica</groupId>
<artifactId>avatica-core</artifactId>
<version>${calcite.avatica.version}</version>
<exclusions>
<exclusion>
<artifactId>protobuf-java</artifactId>
<groupId>com.google.protobuf</groupId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.calcite</groupId>
@@ -99,6 +105,10 @@
<artifactId>log4j</artifactId>
<groupId>log4j</groupId>
</exclusion>
<exclusion>
<artifactId>protobuf-java</artifactId>
<groupId>com.google.protobuf</groupId>
</exclusion>
</exclusions>
</dependency>
<dependency>

View File

@@ -119,6 +119,22 @@
<version>${postgresql.version}</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf-java.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>flight-sql</artifactId>
<version>${flight-sql.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-jdbc</artifactId>
<version>${arrow-jdbc.version}</version>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,93 @@
package com.tencent.supersonic.headless.server.listener;
import com.tencent.supersonic.headless.server.service.FlightService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import lombok.extern.slf4j.Slf4j;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.Location;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.CommandLineRunner;
import org.springframework.stereotype.Component;
/**
* arrow flight jdbc server listener
*/
@Component
@Slf4j
public class FlightSqlListener implements CommandLineRunner {
@Value("${s2.flightSql.enable:false}")
private Boolean enable = false;
@Value("${s2.flightSql.host:localhost}")
private String host = "localhost";
@Value("${s2.flightSql.port:9081}")
private Integer port = 9081;
@Value("${s2.flightSql.executor:4}")
private Integer executor = 4;
@Value("${s2.flightSql.queue:128}")
private Integer queue = 128;
@Value("${s2.flightSql.expireMinute:10}")
private Integer expireMinute = 10;
private final FlightService flightService;
private ExecutorService executorService;
private FlightServer flightServer;
private BufferAllocator allocator;
public FlightSqlListener(FlightService flightService) {
this.allocator = new RootAllocator();
this.flightService = flightService;
this.flightService.setLocation(host, port);
executorService = Executors.newFixedThreadPool(executor);
this.flightService.setExecutorService(executorService, queue, expireMinute);
Location listenLocation = Location.forGrpcInsecure(host, port);
flightServer = FlightServer.builder(allocator, listenLocation, this.flightService)
.build();
}
public String getHost() {
return host;
}
public Integer getPort() {
return port;
}
public void startServer() {
try {
log.info("Arrow Flight JDBC server started on {} {}", host, port);
flightServer.start();
} catch (Exception e) {
log.error("FlightSqlListener start error {}", e);
}
}
@Override
public void run(String... args) throws Exception {
if (enable) {
new Thread() {
@Override
public void run() {
try {
startServer();
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
flightServer.close();
allocator.close();
} catch (Exception e) {
log.error("flightServer close error {}", e);
}
}));
//flightServer.awaitTermination();
} catch (Exception e) {
log.error("run error {}", e);
}
}
}.start();
}
}
}

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.headless.server.service;
import java.util.concurrent.ExecutorService;
import org.apache.arrow.flight.sql.FlightSqlProducer;
public interface FlightService extends FlightSqlProducer {
void setLocation(String host, Integer port);
void setExecutorService(ExecutorService executorService, Integer queue, Integer expireMinute);
}

View File

@@ -0,0 +1,337 @@
package com.tencent.supersonic.headless.server.service.impl;
import static com.google.protobuf.Any.pack;
import static com.google.protobuf.ByteString.copyFrom;
import static java.util.Collections.singletonList;
import static java.util.UUID.randomUUID;
import static org.apache.arrow.adapter.jdbc.JdbcToArrow.sqlToArrowVectorIterator;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowUtils.jdbcToArrowSchema;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.request.UserReq;
import com.tencent.supersonic.auth.api.authentication.service.UserService;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.headless.api.pojo.Param;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.service.FlightService;
import com.tencent.supersonic.headless.server.service.QueryService;
import com.tencent.supersonic.headless.server.utils.FlightUtils;
import java.nio.charset.StandardCharsets;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Types;
import java.util.Arrays;
import java.util.Calendar;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import javax.sql.RowSetMetaData;
import javax.sql.rowset.CachedRowSet;
import javax.sql.rowset.RowSetFactory;
import javax.sql.rowset.RowSetMetaDataImpl;
import javax.sql.rowset.RowSetProvider;
import lombok.extern.slf4j.Slf4j;
import org.apache.arrow.adapter.jdbc.ArrowVectorIterator;
import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightConstants;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.sql.BasicFlightSqlProducer;
import org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest;
import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest;
import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery;
import org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
/**
* arrow flight FlightSqlProducer
*/
@Slf4j
@Service("FlightService")
public class FlightServiceImpl extends BasicFlightSqlProducer implements FlightService {
private String host;
private Integer port;
private ExecutorService executorService;
private Cache<ByteString, SemanticQueryReq> preparedStatementCache;
private final String dataSetIdHeaderKey = "dataSetId";
private final String nameHeaderKey = "name";
private final String passwordHeaderKey = "password";
private final Calendar defaultCalendar = JdbcToArrowUtils.getUtcCalendar();
private final QueryService queryService;
private final AuthenticationConfig authenticationConfig;
private final UserService userService;
public FlightServiceImpl(QueryService queryService,
AuthenticationConfig authenticationConfig,
UserService userService) {
this.queryService = queryService;
this.authenticationConfig = authenticationConfig;
this.userService = userService;
}
public void setLocation(String host, Integer port) {
this.host = host;
this.port = port;
}
@Override
public void setExecutorService(ExecutorService executorService, Integer queue, Integer expireMinute) {
this.executorService = executorService;
this.preparedStatementCache =
CacheBuilder.newBuilder()
.maximumSize(queue)
.expireAfterWrite(expireMinute, TimeUnit.MINUTES)
.build();
}
@Override
public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor) {
return super.getFlightInfo(callContext, flightDescriptor);
}
@Override
public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context,
final ServerStreamListener listener) {
final ByteString handle = ticketStatementQuery.getStatementHandle();
log.info("getStreamStatement {} ", handle);
executeQuery(handle, listener);
}
@Override
public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context,
final FlightDescriptor descriptor) {
try {
ByteString preparedStatementHandle = addPrepared(context, request.getQuery());
TicketStatementQuery ticket = TicketStatementQuery.newBuilder()
.setStatementHandle(preparedStatementHandle)
.build();
return getFlightInfoForSchema(ticket, descriptor, null);
} catch (Exception e) {
log.error("getFlightInfoStatement error {}", e);
}
return null;
}
@Override
public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context,
final ServerStreamListener listener) {
log.info("getStreamPreparedStatement {}", command.getPreparedStatementHandle());
executeQuery(command.getPreparedStatementHandle(), listener);
}
private void executeQuery(ByteString hander, final ServerStreamListener listener) {
SemanticQueryReq semanticQueryReq = preparedStatementCache.getIfPresent(hander);
if (Objects.isNull(semanticQueryReq)) {
listener.error(CallStatus.INTERNAL
.withDescription("Failed to get prepared statement: empty")
.toRuntimeException());
log.error("getStreamPreparedStatement error {}", hander);
listener.completed();
return;
}
executorService.submit(() -> {
BufferAllocator rootAllocator = new RootAllocator();
try {
Optional<Param> authOpt = semanticQueryReq.getParams().stream()
.filter(p -> p.getName().equals(authenticationConfig.getTokenHttpHeaderKey())).findFirst();
if (authOpt.isPresent()) {
User user = UserHolder.findUser(authOpt.get().getValue(),
authenticationConfig.getTokenHttpHeaderAppKey());
SemanticQueryResp resp = queryService.queryByReq(semanticQueryReq, user);
ResultSet resultSet = semanticQueryRespToResultSet(resp, semanticQueryReq.getDataSetId());
final Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), defaultCalendar);
try (final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
final VectorLoader loader = new VectorLoader(vectorSchemaRoot);
listener.start(vectorSchemaRoot);
final ArrowVectorIterator iterator = sqlToArrowVectorIterator(resultSet, rootAllocator);
while (iterator.hasNext()) {
final VectorSchemaRoot batch = iterator.next();
if (batch.getRowCount() == 0) {
break;
}
final VectorUnloader unloader = new VectorUnloader(batch);
loader.load(unloader.getRecordBatch());
listener.putNext();
vectorSchemaRoot.clear();
}
listener.putNext();
}
}
} catch (Exception e) {
listener.error(CallStatus.INTERNAL
.withDescription(String.format("Failed to get exec statement %s", e.getMessage()))
.toRuntimeException());
log.error("getStreamPreparedStatement error {}", hander);
} finally {
preparedStatementCache.invalidate(hander);
listener.completed();
rootAllocator.close();
}
});
}
@Override
public void closePreparedStatement(final ActionClosePreparedStatementRequest request, final CallContext context,
final StreamListener<Result> listener) {
log.info("closePreparedStatement {}", request.getPreparedStatementHandle());
listener.onCompleted();
}
@Override
public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command,
final CallContext context,
final FlightDescriptor descriptor) {
return getFlightInfoForSchema(command, descriptor, null);
}
@Override
public void createPreparedStatement(final ActionCreatePreparedStatementRequest request, final CallContext context,
final StreamListener<Result> listener) {
prepared(request, context, listener);
}
private ByteString addPrepared(final CallContext context, String query) throws Exception {
if (Arrays.asList(dataSetIdHeaderKey, nameHeaderKey, passwordHeaderKey).stream()
.anyMatch(h -> !context.getMiddleware(FlightConstants.HEADER_KEY).headers().containsKey(h))) {
throw new Exception(String.format("Failed to create prepared statement: HeaderCallOption miss %s %s %s",
dataSetIdHeaderKey, nameHeaderKey, passwordHeaderKey));
}
Long dataSetId = Long.valueOf(
context.getMiddleware(FlightConstants.HEADER_KEY).headers().get(dataSetIdHeaderKey));
if (StringUtils.isBlank(query)) {
throw new Exception("Failed to create prepared statement: query is empty");
}
try {
String auth = getUserAuth(context.getMiddleware(FlightConstants.HEADER_KEY).headers());
if (StringUtils.isBlank(auth)) {
throw new Exception("auth empty");
}
final ByteString preparedStatementHandle = copyFrom(
randomUUID().toString().getBytes(StandardCharsets.UTF_8));
QuerySqlReq querySqlReq = new QuerySqlReq();
querySqlReq.setDataSetId(dataSetId);
querySqlReq.setSql(query);
querySqlReq.setParams(Arrays.asList(new Param(authenticationConfig.getTokenHttpHeaderKey(), auth)));
preparedStatementCache.put(preparedStatementHandle, querySqlReq);
log.info("createPreparedStatement {} {} {} ", preparedStatementHandle, dataSetId, query);
return preparedStatementHandle;
} catch (Exception e) {
throw e;
}
}
private void prepared(final ActionCreatePreparedStatementRequest request, final CallContext context,
final StreamListener<Result> listener) {
try {
ByteString preparedStatementHandle = addPrepared(context, request.getQuery());
final ActionCreatePreparedStatementResult result = ActionCreatePreparedStatementResult.newBuilder()
.setDatasetSchema(ByteString.EMPTY)
.setParameterSchema(ByteString.empty())
.setPreparedStatementHandle(preparedStatementHandle)
.build();
listener.onNext(new Result(pack(result).toByteArray()));
} catch (Exception e) {
listener.onError(CallStatus.INTERNAL
.withDescription(String.format("Failed to create prepared statement: %s", e.getMessage()))
.toRuntimeException());
} finally {
listener.onCompleted();
}
}
@Override
protected <T extends Message> List<FlightEndpoint> determineEndpoints(T t, FlightDescriptor flightDescriptor,
Schema schema) {
throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException();
}
private <T extends Message> FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor,
final Schema schema) {
final Ticket ticket = new Ticket(pack(request).toByteArray());
Location listenLocation = Location.forGrpcInsecure(host, port);
final List<FlightEndpoint> endpoints = singletonList(new FlightEndpoint(ticket, listenLocation));
return new FlightInfo(schema, descriptor, endpoints, -1, -1);
}
private String getUserAuth(CallHeaders callHeaders) throws Exception {
UserReq userReq = new UserReq();
userReq.setName(callHeaders.get(nameHeaderKey));
userReq.setPassword(callHeaders.get(passwordHeaderKey));
if (StringUtils.isBlank(userReq.getName()) || StringUtils.isBlank(userReq.getPassword())) {
throw new Exception("name or password is empty");
}
String auth = userService.login(userReq, authenticationConfig.getTokenDefaultAppKey());
return auth;
}
private ResultSet semanticQueryRespToResultSet(SemanticQueryResp resp, Long dataSetId) throws SQLException {
RowSetFactory factory = RowSetProvider.newFactory();
CachedRowSet rowset = factory.createCachedRowSet();
RowSetMetaData rowSetMetaData = new RowSetMetaDataImpl();
int columnNum = resp.getColumns().size();
rowSetMetaData.setColumnCount(columnNum);
for (int i = 1; i <= columnNum; i++) {
String columnName = resp.getColumns().get(i - 1).getNameEn();
rowSetMetaData.setColumnName(i, columnName);
Optional<Map<String, Object>> valOpt = resp.getResultList().stream()
.filter(r -> r.containsKey(columnName) && Objects.nonNull(r.get(columnName))).findFirst();
if (valOpt.isPresent()) {
int type = FlightUtils.resolveType(valOpt.get());
rowSetMetaData.setColumnType(i, type);
rowSetMetaData.setNullable(i, FlightUtils.isNullable(type));
} else {
rowSetMetaData.setNullable(i, ResultSetMetaData.columnNullable);
rowSetMetaData.setColumnType(i, Types.VARCHAR);
}
rowSetMetaData.setCatalogName(i, String.valueOf(dataSetId));
rowSetMetaData.setSchemaName(i, dataSetIdHeaderKey);
}
rowset.setMetaData(rowSetMetaData);
for (Map<String, Object> row : resp.getResultList()) {
rowset.moveToInsertRow();
for (int i = 1; i <= columnNum; i++) {
String columnName = resp.getColumns().get(i - 1).getNameEn();
if (row.containsKey(columnName)) {
rowset.updateObject(i, row.get(columnName));
} else {
rowset.updateObject(i, null);
}
}
rowset.insertRow();
rowset.moveToCurrentRow();
}
return rowset;
}
}

View File

@@ -0,0 +1,47 @@
package com.tencent.supersonic.headless.server.utils;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Types;
import java.util.regex.Pattern;
/**
* tools for arrow flight sql
*/
public class FlightUtils {
public static int resolveType(Object value) {
if (value instanceof Long) {
return Types.BIGINT;
}
if (value instanceof Integer) {
return Types.INTEGER;
}
if (value instanceof Double) {
return Types.DOUBLE;
}
if (value instanceof String) {
String val = String.valueOf(value);
if (Pattern.matches("^\\d+$", val)) {
return Types.BIGINT;
} else if (Pattern.matches("^\\d+\\.\\d+$", val)) {
return Types.DECIMAL;
} else if (Pattern.matches("^\\d{4}-\\d{2}-\\d{2}$", val)) {
return Types.DATE;
} else if (Pattern.matches("^\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}$", val)) {
return Types.TIME;
}
}
return Types.VARCHAR;
}
public static int isNullable(int sqlType) throws SQLException {
switch (sqlType) {
case Types.VARCHAR:
case Types.DECIMAL:
return ResultSetMetaData.columnNullable;
default:
return ResultSetMetaData.columnNullableUnknown;
}
}
}

View File

@@ -0,0 +1,101 @@
package com.tencent.supersonic.headless;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.auth.authentication.strategy.FakeUserStrategy;
import com.tencent.supersonic.headless.server.listener.FlightSqlListener;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.FlightCallHeaders;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.HeaderCallOption;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.memory.RootAllocator;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
public class FlightSqlTest extends BaseTest {
@Autowired
private FlightSqlListener flightSqlListener;
@Autowired
private FakeUserStrategy fakeUserStrategy;
@Test
void test01() throws Exception {
String host = flightSqlListener.getHost();
Integer port = flightSqlListener.getPort();
UserHolder.setStrategy(fakeUserStrategy);
flightSqlListener.startServer();
FlightSqlClient sqlClient = new FlightSqlClient(
FlightClient.builder(new RootAllocator(Integer.MAX_VALUE), Location.forGrpcInsecure(host, port))
.build());
CallHeaders headers = new FlightCallHeaders();
headers.insert("dataSetId", "1");
headers.insert("name", "admin");
headers.insert("password", "admin");
HeaderCallOption headerOption = new HeaderCallOption(headers);
try (final FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(
"SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门",
headerOption)) {
final FlightInfo info = preparedStatement.execute();
FlightStream stream = sqlClient.getStream(info
.getEndpoints()
.get(0).getTicket());
int rowCnt = 0;
int colCnt = 0;
while (stream.next()) {
if (stream.getRoot().getRowCount() > 0) {
colCnt = stream.getRoot().getFieldVectors().size();
rowCnt += stream.getRoot().getRowCount();
}
}
assertEquals(2, colCnt);
assertTrue(rowCnt > 0);
} catch (Exception e) {
e.printStackTrace();
}
}
@Test
void test02() throws Exception {
String host = flightSqlListener.getHost();
Integer port = flightSqlListener.getPort();
UserHolder.setStrategy(fakeUserStrategy);
flightSqlListener.startServer();
FlightSqlClient sqlClient = new FlightSqlClient(
FlightClient.builder(new RootAllocator(Integer.MAX_VALUE), Location.forGrpcInsecure(host, port))
.build());
CallHeaders headers = new FlightCallHeaders();
headers.insert("dataSetId", "1");
headers.insert("name", "admin");
headers.insert("password", "admin");
HeaderCallOption headerOption = new HeaderCallOption(headers);
try {
FlightInfo flightInfo = sqlClient.execute("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门",
headerOption);
FlightStream stream = sqlClient.getStream(flightInfo
.getEndpoints()
.get(0).getTicket());
int rowCnt = 0;
int colCnt = 0;
while (stream.next()) {
if (stream.getRoot().getRowCount() > 0) {
colCnt = stream.getRoot().getFieldVectors().size();
rowCnt += stream.getRoot().getRowCount();
}
}
assertEquals(2, colCnt);
assertTrue(rowCnt > 0);
} catch (Exception e) {
e.printStackTrace();
}
}
}

View File

@@ -1,5 +1,5 @@
-- sample user
MERGE INTO s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1);
MERGE INTO s2_user (id, `name`, password, salt, display_name, email, is_admin) values (1, 'admin','c3VwZXJzb25pY0BiaWNvbTD12g9wGXESwL7+o7xUW90=','jGl25bVBBBW96Qi9Te4V3w==','admin','admin@xx.com', 1);
MERGE INTO s2_user (id, `name`, password, display_name, email) values (2, 'jack','123456','jack','jack@xx.com');
MERGE INTO s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com');
MERGE INTO s2_user (id, `name`, password, display_name, email, is_admin) values (4, 'lucy','123456','lucy','lucy@xx.com', 1);

View File

@@ -91,6 +91,7 @@ create table IF NOT EXISTS s2_user
password varchar(100) null,
email varchar(100) null,
is_admin INT null,
salt varchar(100) null,
PRIMARY KEY (`id`)
);
COMMENT ON TABLE s2_user IS 'user information table';

View File

@@ -77,6 +77,10 @@
<postgresql.version>42.7.1</postgresql.version>
<st.version>4.0.8</st.version>
<duckdb_jdbc.version>0.10.0</duckdb_jdbc.version>
<protobuf-java.version>3.23.1</protobuf-java.version>
<flight-sql.version>15.0.2</flight-sql.version>
<arrow-jdbc.version>15.0.2</arrow-jdbc.version>
<flight-sql-jdbc-driver.version>15.0.2</flight-sql-jdbc-driver.version>
</properties>
<dependencyManagement>