From abf440c4d45f0b0dda47186fd59ea9ca0bba139e Mon Sep 17 00:00:00 2001 From: jipeli <54889677+jipeli@users.noreply.github.com> Date: Wed, 12 Jun 2024 15:02:55 +0800 Subject: [PATCH] (feature)(Headless) arrow flight sql endpoint (#634) (#1136) --- .../authentication/adaptor/UserAdaptor.java | 2 + .../authentication/service/UserService.java | 2 + .../authentication/service/UserStrategy.java | 2 + .../api/authentication/utils/UserHolder.java | 12 +- .../adaptor/DefaultUserAdaptor.java | 24 +- .../service/UserServiceImpl.java | 5 + .../strategy/FakeUserStrategy.java | 5 + .../strategy/HttpHeaderUserStrategy.java | 5 + .../authentication/utils/UserTokenUtils.java | 25 +- headless/core/pom.xml | 10 + headless/server/pom.xml | 16 + .../server/listener/FlightSqlListener.java | 93 +++++ .../server/service/FlightService.java | 11 + .../service/impl/FlightServiceImpl.java | 337 ++++++++++++++++++ .../headless/server/utils/FlightUtils.java | 47 +++ .../supersonic/headless/FlightSqlTest.java | 101 ++++++ .../src/test/resources/db/data-h2.sql | 2 +- .../src/test/resources/db/schema-h2.sql | 1 + pom.xml | 4 + 19 files changed, 697 insertions(+), 7 deletions(-) create mode 100644 headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/FlightSqlListener.java create mode 100644 headless/server/src/main/java/com/tencent/supersonic/headless/server/service/FlightService.java create mode 100644 headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/FlightServiceImpl.java create mode 100644 headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/FlightUtils.java create mode 100644 launchers/standalone/src/test/java/com/tencent/supersonic/headless/FlightSqlTest.java diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/adaptor/UserAdaptor.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/adaptor/UserAdaptor.java index f080a666d..72ea1ee32 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/adaptor/UserAdaptor.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/adaptor/UserAdaptor.java @@ -23,6 +23,8 @@ public interface UserAdaptor { String login(UserReq userReq, HttpServletRequest request); + String login(UserReq userReq, String appKey); + List getUserByOrg(String key); Set getUserAllOrgId(String userName); diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserService.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserService.java index dac5728b7..54d2495bd 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserService.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserService.java @@ -21,6 +21,8 @@ public interface UserService { String login(UserReq userCmd, HttpServletRequest request); + String login(UserReq userCmd, String appKey); + Set getUserAllOrgId(String userName); List getUserByOrg(String key); diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserStrategy.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserStrategy.java index ae9d73a85..f5106da45 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserStrategy.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserStrategy.java @@ -11,4 +11,6 @@ public interface UserStrategy { User findUser(HttpServletRequest request, HttpServletResponse response); + User findUser(String token, String appKey); + } diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/utils/UserHolder.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/utils/UserHolder.java index 1425b1c37..d6838f101 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/utils/UserHolder.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/utils/UserHolder.java @@ -5,10 +5,9 @@ import com.tencent.supersonic.auth.api.authentication.service.UserStrategy; import com.tencent.supersonic.common.pojo.SystemConfig; import com.tencent.supersonic.common.service.SystemConfigService; import com.tencent.supersonic.common.util.ContextUtils; -import org.springframework.util.CollectionUtils; - import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.util.CollectionUtils; public final class UserHolder { @@ -20,6 +19,15 @@ public final class UserHolder { public static User findUser(HttpServletRequest request, HttpServletResponse response) { User user = REPO.findUser(request, response); + return getUser(user); + } + + public static User findUser(String token, String appKey) { + User user = REPO.findUser(token, appKey); + return getUser(user); + } + + private static User getUser(User user) { SystemConfigService sysParameterService = ContextUtils.getBean(SystemConfigService.class); SystemConfig systemConfig = sysParameterService.getSystemConfig(); if (!CollectionUtils.isEmpty(systemConfig.getAdmins()) diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java index b3e9a7522..902ae155c 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java @@ -90,25 +90,43 @@ public class DefaultUserAdaptor implements UserAdaptor { @Override public String login(UserReq userReq, HttpServletRequest request) { UserTokenUtils userTokenUtils = ContextUtils.getBean(UserTokenUtils.class); + try { + UserWithPassword user = getUserWithPassword(userReq); + return userTokenUtils.generateToken(user, request); + } catch (Exception e) { + throw new RuntimeException("password encrypt error, please try again"); + } + } + + @Override + public String login(UserReq userReq, String appKey) { + UserTokenUtils userTokenUtils = ContextUtils.getBean(UserTokenUtils.class); + try { + UserWithPassword user = getUserWithPassword(userReq); + return userTokenUtils.generateToken(user, appKey); + } catch (Exception e) { + throw new RuntimeException("password encrypt error, please try again"); + } + } + + private UserWithPassword getUserWithPassword(UserReq userReq) { UserDO userDO = getUser(userReq.getName()); if (userDO == null) { throw new RuntimeException("user not exist,please register"); } - try { String password = AESEncryptionUtil.encrypt(userReq.getPassword(), AESEncryptionUtil.getBytesFromString(userDO.getSalt())); if (userDO.getPassword().equals(password)) { UserWithPassword user = UserWithPassword.get(userDO.getId(), userDO.getName(), userDO.getDisplayName(), userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin()); - return userTokenUtils.generateToken(user, request); + return user; } else { throw new RuntimeException("password not correct, please try again"); } } catch (Exception e) { throw new RuntimeException("password encrypt error, please try again"); } - } @Override diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java index 013dcc504..7aa8ac9c2 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java @@ -72,4 +72,9 @@ public class UserServiceImpl implements UserService { return ComponentFactory.getUserAdaptor().login(userReq, request); } + @Override + public String login(UserReq userReq, String appKey) { + return ComponentFactory.getUserAdaptor().login(userReq, appKey); + } + } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/FakeUserStrategy.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/FakeUserStrategy.java index 90eff1ed5..818ec779e 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/FakeUserStrategy.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/FakeUserStrategy.java @@ -20,4 +20,9 @@ public class FakeUserStrategy implements UserStrategy { return User.getFakeUser(); } + @Override + public User findUser(String token, String appKey) { + return User.getFakeUser(); + } + } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/HttpHeaderUserStrategy.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/HttpHeaderUserStrategy.java index 8da6fe324..0996b607f 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/HttpHeaderUserStrategy.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/HttpHeaderUserStrategy.java @@ -28,4 +28,9 @@ public class HttpHeaderUserStrategy implements UserStrategy { public User findUser(HttpServletRequest request, HttpServletResponse response) { return userTokenUtils.getUser(request); } + + @Override + public User findUser(String token, String appKey) { + return userTokenUtils.getUser(token, appKey); + } } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java index e9c33790c..a05ab4d0b 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java @@ -37,6 +37,11 @@ public class UserTokenUtils { } public String generateToken(UserWithPassword user, HttpServletRequest request) { + String appKey = getAppKey(request); + return generateToken(user, appKey); + } + + public String generateToken(UserWithPassword user, String appKey) { Map claims = new HashMap<>(5); claims.put(TOKEN_USER_ID, user.getId()); claims.put(TOKEN_USER_NAME, StringUtils.isEmpty(user.getName()) ? "" : user.getName()); @@ -44,7 +49,6 @@ public class UserTokenUtils { claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName()); claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis()); claims.put(TOKEN_IS_ADMIN, user.getIsAdmin()); - String appKey = getAppKey(request); return generate(claims, appKey); } @@ -61,6 +65,15 @@ public class UserTokenUtils { public User getUser(HttpServletRequest request) { String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey()); final Claims claims = getClaims(token, request); + return getUser(claims); + } + + public User getUser(String token, String appKey) { + final Claims claims = getClaims(token, appKey); + return getUser(claims); + } + + private User getUser(Claims claims) { Long userId = Long.parseLong(claims.getOrDefault(TOKEN_USER_ID, 0).toString()); String userName = String.valueOf(claims.get(TOKEN_USER_NAME)); String email = String.valueOf(claims.get(TOKEN_USER_EMAIL)); @@ -92,6 +105,16 @@ public class UserTokenUtils { Claims claims; try { String appKey = getAppKey(request); + claims = getClaims(token, appKey); + } catch (Exception e) { + throw new AccessException("parse user info from token failed :" + token); + } + return claims; + } + + private Claims getClaims(String token, String appKey) { + Claims claims; + try { String tokenSecret = getTokenSecret(appKey); claims = Jwts.parser() .setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8)) diff --git a/headless/core/pom.xml b/headless/core/pom.xml index 50ef311d1..0c1b0d0d2 100644 --- a/headless/core/pom.xml +++ b/headless/core/pom.xml @@ -89,6 +89,12 @@ org.apache.calcite.avatica avatica-core ${calcite.avatica.version} + + + protobuf-java + com.google.protobuf + + org.apache.calcite @@ -99,6 +105,10 @@ log4j log4j + + protobuf-java + com.google.protobuf + diff --git a/headless/server/pom.xml b/headless/server/pom.xml index 420ae599e..353db2ccf 100644 --- a/headless/server/pom.xml +++ b/headless/server/pom.xml @@ -119,6 +119,22 @@ ${postgresql.version} + + com.google.protobuf + protobuf-java + ${protobuf-java.version} + + + org.apache.arrow + flight-sql + ${flight-sql.version} + + + org.apache.arrow + arrow-jdbc + ${arrow-jdbc.version} + + \ No newline at end of file diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/FlightSqlListener.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/FlightSqlListener.java new file mode 100644 index 000000000..c4ff30ced --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/FlightSqlListener.java @@ -0,0 +1,93 @@ +package com.tencent.supersonic.headless.server.listener; + +import com.tencent.supersonic.headless.server.service.FlightService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import lombok.extern.slf4j.Slf4j; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.CommandLineRunner; +import org.springframework.stereotype.Component; + +/** + * arrow flight jdbc server listener + */ +@Component +@Slf4j +public class FlightSqlListener implements CommandLineRunner { + + @Value("${s2.flightSql.enable:false}") + private Boolean enable = false; + @Value("${s2.flightSql.host:localhost}") + private String host = "localhost"; + @Value("${s2.flightSql.port:9081}") + private Integer port = 9081; + @Value("${s2.flightSql.executor:4}") + private Integer executor = 4; + @Value("${s2.flightSql.queue:128}") + private Integer queue = 128; + @Value("${s2.flightSql.expireMinute:10}") + private Integer expireMinute = 10; + + private final FlightService flightService; + private ExecutorService executorService; + private FlightServer flightServer; + private BufferAllocator allocator; + + public FlightSqlListener(FlightService flightService) { + this.allocator = new RootAllocator(); + this.flightService = flightService; + this.flightService.setLocation(host, port); + executorService = Executors.newFixedThreadPool(executor); + this.flightService.setExecutorService(executorService, queue, expireMinute); + Location listenLocation = Location.forGrpcInsecure(host, port); + flightServer = FlightServer.builder(allocator, listenLocation, this.flightService) + .build(); + } + + public String getHost() { + return host; + } + + public Integer getPort() { + return port; + } + + public void startServer() { + try { + log.info("Arrow Flight JDBC server started on {} {}", host, port); + flightServer.start(); + } catch (Exception e) { + log.error("FlightSqlListener start error {}", e); + } + + } + + @Override + public void run(String... args) throws Exception { + if (enable) { + new Thread() { + @Override + public void run() { + try { + startServer(); + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + flightServer.close(); + allocator.close(); + } catch (Exception e) { + log.error("flightServer close error {}", e); + } + })); + //flightServer.awaitTermination(); + } catch (Exception e) { + log.error("run error {}", e); + } + } + }.start(); + } + } +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/FlightService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/FlightService.java new file mode 100644 index 000000000..572f9fdbf --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/FlightService.java @@ -0,0 +1,11 @@ +package com.tencent.supersonic.headless.server.service; + +import java.util.concurrent.ExecutorService; +import org.apache.arrow.flight.sql.FlightSqlProducer; + +public interface FlightService extends FlightSqlProducer { + + void setLocation(String host, Integer port); + + void setExecutorService(ExecutorService executorService, Integer queue, Integer expireMinute); +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/FlightServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/FlightServiceImpl.java new file mode 100644 index 000000000..1c055a9d9 --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/FlightServiceImpl.java @@ -0,0 +1,337 @@ +package com.tencent.supersonic.headless.server.service.impl; + +import static com.google.protobuf.Any.pack; +import static com.google.protobuf.ByteString.copyFrom; +import static java.util.Collections.singletonList; +import static java.util.UUID.randomUUID; +import static org.apache.arrow.adapter.jdbc.JdbcToArrow.sqlToArrowVectorIterator; +import static org.apache.arrow.adapter.jdbc.JdbcToArrowUtils.jdbcToArrowSchema; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; +import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.auth.api.authentication.request.UserReq; +import com.tencent.supersonic.auth.api.authentication.service.UserService; +import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.headless.api.pojo.Param; +import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; +import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; +import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; +import com.tencent.supersonic.headless.server.service.FlightService; +import com.tencent.supersonic.headless.server.service.QueryService; +import com.tencent.supersonic.headless.server.utils.FlightUtils; +import java.nio.charset.StandardCharsets; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Types; +import java.util.Arrays; +import java.util.Calendar; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import javax.sql.RowSetMetaData; +import javax.sql.rowset.CachedRowSet; +import javax.sql.rowset.RowSetFactory; +import javax.sql.rowset.RowSetMetaDataImpl; +import javax.sql.rowset.RowSetProvider; +import lombok.extern.slf4j.Slf4j; +import org.apache.arrow.adapter.jdbc.ArrowVectorIterator; +import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightConstants; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.BasicFlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; +import org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.commons.lang3.StringUtils; +import org.springframework.stereotype.Service; + +/** + * arrow flight FlightSqlProducer + */ +@Slf4j +@Service("FlightService") +public class FlightServiceImpl extends BasicFlightSqlProducer implements FlightService { + + private String host; + private Integer port; + private ExecutorService executorService; + private Cache preparedStatementCache; + private final String dataSetIdHeaderKey = "dataSetId"; + private final String nameHeaderKey = "name"; + private final String passwordHeaderKey = "password"; + private final Calendar defaultCalendar = JdbcToArrowUtils.getUtcCalendar(); + private final QueryService queryService; + private final AuthenticationConfig authenticationConfig; + private final UserService userService; + + public FlightServiceImpl(QueryService queryService, + AuthenticationConfig authenticationConfig, + UserService userService) { + this.queryService = queryService; + this.authenticationConfig = authenticationConfig; + + this.userService = userService; + } + + public void setLocation(String host, Integer port) { + this.host = host; + this.port = port; + } + + @Override + public void setExecutorService(ExecutorService executorService, Integer queue, Integer expireMinute) { + this.executorService = executorService; + this.preparedStatementCache = + CacheBuilder.newBuilder() + .maximumSize(queue) + .expireAfterWrite(expireMinute, TimeUnit.MINUTES) + .build(); + } + + @Override + public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor) { + return super.getFlightInfo(callContext, flightDescriptor); + } + + @Override + public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context, + final ServerStreamListener listener) { + final ByteString handle = ticketStatementQuery.getStatementHandle(); + log.info("getStreamStatement {} ", handle); + executeQuery(handle, listener); + } + + @Override + public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context, + final FlightDescriptor descriptor) { + try { + ByteString preparedStatementHandle = addPrepared(context, request.getQuery()); + TicketStatementQuery ticket = TicketStatementQuery.newBuilder() + .setStatementHandle(preparedStatementHandle) + .build(); + return getFlightInfoForSchema(ticket, descriptor, null); + } catch (Exception e) { + log.error("getFlightInfoStatement error {}", e); + } + return null; + } + + @Override + public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context, + final ServerStreamListener listener) { + log.info("getStreamPreparedStatement {}", command.getPreparedStatementHandle()); + executeQuery(command.getPreparedStatementHandle(), listener); + } + + private void executeQuery(ByteString hander, final ServerStreamListener listener) { + SemanticQueryReq semanticQueryReq = preparedStatementCache.getIfPresent(hander); + if (Objects.isNull(semanticQueryReq)) { + listener.error(CallStatus.INTERNAL + .withDescription("Failed to get prepared statement: empty") + .toRuntimeException()); + log.error("getStreamPreparedStatement error {}", hander); + listener.completed(); + return; + } + executorService.submit(() -> { + BufferAllocator rootAllocator = new RootAllocator(); + try { + Optional authOpt = semanticQueryReq.getParams().stream() + .filter(p -> p.getName().equals(authenticationConfig.getTokenHttpHeaderKey())).findFirst(); + if (authOpt.isPresent()) { + User user = UserHolder.findUser(authOpt.get().getValue(), + authenticationConfig.getTokenHttpHeaderAppKey()); + SemanticQueryResp resp = queryService.queryByReq(semanticQueryReq, user); + ResultSet resultSet = semanticQueryRespToResultSet(resp, semanticQueryReq.getDataSetId()); + final Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), defaultCalendar); + try (final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) { + final VectorLoader loader = new VectorLoader(vectorSchemaRoot); + listener.start(vectorSchemaRoot); + final ArrowVectorIterator iterator = sqlToArrowVectorIterator(resultSet, rootAllocator); + while (iterator.hasNext()) { + final VectorSchemaRoot batch = iterator.next(); + if (batch.getRowCount() == 0) { + break; + } + final VectorUnloader unloader = new VectorUnloader(batch); + loader.load(unloader.getRecordBatch()); + listener.putNext(); + vectorSchemaRoot.clear(); + } + + listener.putNext(); + } + + } + } catch (Exception e) { + listener.error(CallStatus.INTERNAL + .withDescription(String.format("Failed to get exec statement %s", e.getMessage())) + .toRuntimeException()); + log.error("getStreamPreparedStatement error {}", hander); + } finally { + preparedStatementCache.invalidate(hander); + listener.completed(); + rootAllocator.close(); + } + }); + } + + @Override + public void closePreparedStatement(final ActionClosePreparedStatementRequest request, final CallContext context, + final StreamListener listener) { + log.info("closePreparedStatement {}", request.getPreparedStatementHandle()); + listener.onCompleted(); + } + + @Override + public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command, + final CallContext context, + final FlightDescriptor descriptor) { + return getFlightInfoForSchema(command, descriptor, null); + } + + @Override + public void createPreparedStatement(final ActionCreatePreparedStatementRequest request, final CallContext context, + final StreamListener listener) { + prepared(request, context, listener); + } + + private ByteString addPrepared(final CallContext context, String query) throws Exception { + if (Arrays.asList(dataSetIdHeaderKey, nameHeaderKey, passwordHeaderKey).stream() + .anyMatch(h -> !context.getMiddleware(FlightConstants.HEADER_KEY).headers().containsKey(h))) { + throw new Exception(String.format("Failed to create prepared statement: HeaderCallOption miss %s %s %s", + dataSetIdHeaderKey, nameHeaderKey, passwordHeaderKey)); + } + Long dataSetId = Long.valueOf( + context.getMiddleware(FlightConstants.HEADER_KEY).headers().get(dataSetIdHeaderKey)); + if (StringUtils.isBlank(query)) { + throw new Exception("Failed to create prepared statement: query is empty"); + } + try { + String auth = getUserAuth(context.getMiddleware(FlightConstants.HEADER_KEY).headers()); + if (StringUtils.isBlank(auth)) { + throw new Exception("auth empty"); + } + final ByteString preparedStatementHandle = copyFrom( + randomUUID().toString().getBytes(StandardCharsets.UTF_8)); + QuerySqlReq querySqlReq = new QuerySqlReq(); + querySqlReq.setDataSetId(dataSetId); + querySqlReq.setSql(query); + querySqlReq.setParams(Arrays.asList(new Param(authenticationConfig.getTokenHttpHeaderKey(), auth))); + preparedStatementCache.put(preparedStatementHandle, querySqlReq); + log.info("createPreparedStatement {} {} {} ", preparedStatementHandle, dataSetId, query); + return preparedStatementHandle; + } catch (Exception e) { + throw e; + } + } + + private void prepared(final ActionCreatePreparedStatementRequest request, final CallContext context, + final StreamListener listener) { + try { + ByteString preparedStatementHandle = addPrepared(context, request.getQuery()); + final ActionCreatePreparedStatementResult result = ActionCreatePreparedStatementResult.newBuilder() + .setDatasetSchema(ByteString.EMPTY) + .setParameterSchema(ByteString.empty()) + .setPreparedStatementHandle(preparedStatementHandle) + .build(); + listener.onNext(new Result(pack(result).toByteArray())); + } catch (Exception e) { + listener.onError(CallStatus.INTERNAL + .withDescription(String.format("Failed to create prepared statement: %s", e.getMessage())) + .toRuntimeException()); + } finally { + listener.onCompleted(); + } + } + + @Override + protected List determineEndpoints(T t, FlightDescriptor flightDescriptor, + Schema schema) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + private FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor, + final Schema schema) { + final Ticket ticket = new Ticket(pack(request).toByteArray()); + Location listenLocation = Location.forGrpcInsecure(host, port); + final List endpoints = singletonList(new FlightEndpoint(ticket, listenLocation)); + + return new FlightInfo(schema, descriptor, endpoints, -1, -1); + } + + private String getUserAuth(CallHeaders callHeaders) throws Exception { + + UserReq userReq = new UserReq(); + userReq.setName(callHeaders.get(nameHeaderKey)); + userReq.setPassword(callHeaders.get(passwordHeaderKey)); + if (StringUtils.isBlank(userReq.getName()) || StringUtils.isBlank(userReq.getPassword())) { + throw new Exception("name or password is empty"); + } + String auth = userService.login(userReq, authenticationConfig.getTokenDefaultAppKey()); + return auth; + } + + private ResultSet semanticQueryRespToResultSet(SemanticQueryResp resp, Long dataSetId) throws SQLException { + RowSetFactory factory = RowSetProvider.newFactory(); + CachedRowSet rowset = factory.createCachedRowSet(); + RowSetMetaData rowSetMetaData = new RowSetMetaDataImpl(); + int columnNum = resp.getColumns().size(); + rowSetMetaData.setColumnCount(columnNum); + for (int i = 1; i <= columnNum; i++) { + String columnName = resp.getColumns().get(i - 1).getNameEn(); + rowSetMetaData.setColumnName(i, columnName); + Optional> valOpt = resp.getResultList().stream() + .filter(r -> r.containsKey(columnName) && Objects.nonNull(r.get(columnName))).findFirst(); + if (valOpt.isPresent()) { + int type = FlightUtils.resolveType(valOpt.get()); + rowSetMetaData.setColumnType(i, type); + rowSetMetaData.setNullable(i, FlightUtils.isNullable(type)); + } else { + rowSetMetaData.setNullable(i, ResultSetMetaData.columnNullable); + rowSetMetaData.setColumnType(i, Types.VARCHAR); + } + rowSetMetaData.setCatalogName(i, String.valueOf(dataSetId)); + rowSetMetaData.setSchemaName(i, dataSetIdHeaderKey); + } + rowset.setMetaData(rowSetMetaData); + for (Map row : resp.getResultList()) { + rowset.moveToInsertRow(); + for (int i = 1; i <= columnNum; i++) { + String columnName = resp.getColumns().get(i - 1).getNameEn(); + if (row.containsKey(columnName)) { + rowset.updateObject(i, row.get(columnName)); + } else { + rowset.updateObject(i, null); + } + } + rowset.insertRow(); + rowset.moveToCurrentRow(); + } + return rowset; + } +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/FlightUtils.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/FlightUtils.java new file mode 100644 index 000000000..e185f0a64 --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/FlightUtils.java @@ -0,0 +1,47 @@ +package com.tencent.supersonic.headless.server.utils; + +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Types; +import java.util.regex.Pattern; + +/** + * tools for arrow flight sql + */ +public class FlightUtils { + + public static int resolveType(Object value) { + if (value instanceof Long) { + return Types.BIGINT; + } + if (value instanceof Integer) { + return Types.INTEGER; + } + if (value instanceof Double) { + return Types.DOUBLE; + } + if (value instanceof String) { + String val = String.valueOf(value); + if (Pattern.matches("^\\d+$", val)) { + return Types.BIGINT; + } else if (Pattern.matches("^\\d+\\.\\d+$", val)) { + return Types.DECIMAL; + } else if (Pattern.matches("^\\d{4}-\\d{2}-\\d{2}$", val)) { + return Types.DATE; + } else if (Pattern.matches("^\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}$", val)) { + return Types.TIME; + } + } + return Types.VARCHAR; + } + + public static int isNullable(int sqlType) throws SQLException { + switch (sqlType) { + case Types.VARCHAR: + case Types.DECIMAL: + return ResultSetMetaData.columnNullable; + default: + return ResultSetMetaData.columnNullableUnknown; + } + } +} diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/FlightSqlTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/FlightSqlTest.java new file mode 100644 index 000000000..fea1acc28 --- /dev/null +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/FlightSqlTest.java @@ -0,0 +1,101 @@ +package com.tencent.supersonic.headless; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.auth.authentication.strategy.FakeUserStrategy; +import com.tencent.supersonic.headless.server.listener.FlightSqlListener; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.FlightCallHeaders; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.HeaderCallOption; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.memory.RootAllocator; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; + +public class FlightSqlTest extends BaseTest { + + + @Autowired + private FlightSqlListener flightSqlListener; + @Autowired + private FakeUserStrategy fakeUserStrategy; + + @Test + void test01() throws Exception { + String host = flightSqlListener.getHost(); + Integer port = flightSqlListener.getPort(); + UserHolder.setStrategy(fakeUserStrategy); + flightSqlListener.startServer(); + FlightSqlClient sqlClient = new FlightSqlClient( + FlightClient.builder(new RootAllocator(Integer.MAX_VALUE), Location.forGrpcInsecure(host, port)) + .build()); + + CallHeaders headers = new FlightCallHeaders(); + headers.insert("dataSetId", "1"); + headers.insert("name", "admin"); + headers.insert("password", "admin"); + HeaderCallOption headerOption = new HeaderCallOption(headers); + try (final FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( + "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门", + headerOption)) { + final FlightInfo info = preparedStatement.execute(); + FlightStream stream = sqlClient.getStream(info + .getEndpoints() + .get(0).getTicket()); + int rowCnt = 0; + int colCnt = 0; + while (stream.next()) { + if (stream.getRoot().getRowCount() > 0) { + colCnt = stream.getRoot().getFieldVectors().size(); + rowCnt += stream.getRoot().getRowCount(); + } + } + assertEquals(2, colCnt); + assertTrue(rowCnt > 0); + } catch (Exception e) { + e.printStackTrace(); + } + } + + @Test + void test02() throws Exception { + String host = flightSqlListener.getHost(); + Integer port = flightSqlListener.getPort(); + UserHolder.setStrategy(fakeUserStrategy); + flightSqlListener.startServer(); + FlightSqlClient sqlClient = new FlightSqlClient( + FlightClient.builder(new RootAllocator(Integer.MAX_VALUE), Location.forGrpcInsecure(host, port)) + .build()); + + CallHeaders headers = new FlightCallHeaders(); + headers.insert("dataSetId", "1"); + headers.insert("name", "admin"); + headers.insert("password", "admin"); + HeaderCallOption headerOption = new HeaderCallOption(headers); + try { + FlightInfo flightInfo = sqlClient.execute("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门", + headerOption); + FlightStream stream = sqlClient.getStream(flightInfo + .getEndpoints() + .get(0).getTicket()); + int rowCnt = 0; + int colCnt = 0; + while (stream.next()) { + if (stream.getRoot().getRowCount() > 0) { + colCnt = stream.getRoot().getFieldVectors().size(); + rowCnt += stream.getRoot().getRowCount(); + } + } + assertEquals(2, colCnt); + assertTrue(rowCnt > 0); + } catch (Exception e) { + e.printStackTrace(); + } + } +} diff --git a/launchers/standalone/src/test/resources/db/data-h2.sql b/launchers/standalone/src/test/resources/db/data-h2.sql index a3321f732..72de65847 100644 --- a/launchers/standalone/src/test/resources/db/data-h2.sql +++ b/launchers/standalone/src/test/resources/db/data-h2.sql @@ -1,5 +1,5 @@ -- sample user -MERGE INTO s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1); +MERGE INTO s2_user (id, `name`, password, salt, display_name, email, is_admin) values (1, 'admin','c3VwZXJzb25pY0BiaWNvbTD12g9wGXESwL7+o7xUW90=','jGl25bVBBBW96Qi9Te4V3w==','admin','admin@xx.com', 1); MERGE INTO s2_user (id, `name`, password, display_name, email) values (2, 'jack','123456','jack','jack@xx.com'); MERGE INTO s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com'); MERGE INTO s2_user (id, `name`, password, display_name, email, is_admin) values (4, 'lucy','123456','lucy','lucy@xx.com', 1); diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index 7bbe98b65..feb056c38 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -91,6 +91,7 @@ create table IF NOT EXISTS s2_user password varchar(100) null, email varchar(100) null, is_admin INT null, + salt varchar(100) null, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_user IS 'user information table'; diff --git a/pom.xml b/pom.xml index 2650dfe89..2d02b8b94 100644 --- a/pom.xml +++ b/pom.xml @@ -77,6 +77,10 @@ 42.7.1 4.0.8 0.10.0 + 3.23.1 + 15.0.2 + 15.0.2 + 15.0.2