mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
@@ -23,6 +23,8 @@ public interface UserAdaptor {
|
||||
|
||||
String login(UserReq userReq, HttpServletRequest request);
|
||||
|
||||
String login(UserReq userReq, String appKey);
|
||||
|
||||
List<User> getUserByOrg(String key);
|
||||
|
||||
Set<String> getUserAllOrgId(String userName);
|
||||
|
||||
@@ -21,6 +21,8 @@ public interface UserService {
|
||||
|
||||
String login(UserReq userCmd, HttpServletRequest request);
|
||||
|
||||
String login(UserReq userCmd, String appKey);
|
||||
|
||||
Set<String> getUserAllOrgId(String userName);
|
||||
|
||||
List<User> getUserByOrg(String key);
|
||||
|
||||
@@ -11,4 +11,6 @@ public interface UserStrategy {
|
||||
|
||||
User findUser(HttpServletRequest request, HttpServletResponse response);
|
||||
|
||||
User findUser(String token, String appKey);
|
||||
|
||||
}
|
||||
|
||||
@@ -5,10 +5,9 @@ import com.tencent.supersonic.auth.api.authentication.service.UserStrategy;
|
||||
import com.tencent.supersonic.common.pojo.SystemConfig;
|
||||
import com.tencent.supersonic.common.service.SystemConfigService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
public final class UserHolder {
|
||||
|
||||
@@ -20,6 +19,15 @@ public final class UserHolder {
|
||||
|
||||
public static User findUser(HttpServletRequest request, HttpServletResponse response) {
|
||||
User user = REPO.findUser(request, response);
|
||||
return getUser(user);
|
||||
}
|
||||
|
||||
public static User findUser(String token, String appKey) {
|
||||
User user = REPO.findUser(token, appKey);
|
||||
return getUser(user);
|
||||
}
|
||||
|
||||
private static User getUser(User user) {
|
||||
SystemConfigService sysParameterService = ContextUtils.getBean(SystemConfigService.class);
|
||||
SystemConfig systemConfig = sysParameterService.getSystemConfig();
|
||||
if (!CollectionUtils.isEmpty(systemConfig.getAdmins())
|
||||
|
||||
@@ -90,25 +90,43 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
||||
@Override
|
||||
public String login(UserReq userReq, HttpServletRequest request) {
|
||||
UserTokenUtils userTokenUtils = ContextUtils.getBean(UserTokenUtils.class);
|
||||
try {
|
||||
UserWithPassword user = getUserWithPassword(userReq);
|
||||
return userTokenUtils.generateToken(user, request);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("password encrypt error, please try again");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String login(UserReq userReq, String appKey) {
|
||||
UserTokenUtils userTokenUtils = ContextUtils.getBean(UserTokenUtils.class);
|
||||
try {
|
||||
UserWithPassword user = getUserWithPassword(userReq);
|
||||
return userTokenUtils.generateToken(user, appKey);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("password encrypt error, please try again");
|
||||
}
|
||||
}
|
||||
|
||||
private UserWithPassword getUserWithPassword(UserReq userReq) {
|
||||
UserDO userDO = getUser(userReq.getName());
|
||||
if (userDO == null) {
|
||||
throw new RuntimeException("user not exist,please register");
|
||||
}
|
||||
|
||||
try {
|
||||
String password = AESEncryptionUtil.encrypt(userReq.getPassword(),
|
||||
AESEncryptionUtil.getBytesFromString(userDO.getSalt()));
|
||||
if (userDO.getPassword().equals(password)) {
|
||||
UserWithPassword user = UserWithPassword.get(userDO.getId(), userDO.getName(), userDO.getDisplayName(),
|
||||
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
|
||||
return userTokenUtils.generateToken(user, request);
|
||||
return user;
|
||||
} else {
|
||||
throw new RuntimeException("password not correct, please try again");
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("password encrypt error, please try again");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -72,4 +72,9 @@ public class UserServiceImpl implements UserService {
|
||||
return ComponentFactory.getUserAdaptor().login(userReq, request);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String login(UserReq userReq, String appKey) {
|
||||
return ComponentFactory.getUserAdaptor().login(userReq, appKey);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -20,4 +20,9 @@ public class FakeUserStrategy implements UserStrategy {
|
||||
return User.getFakeUser();
|
||||
}
|
||||
|
||||
@Override
|
||||
public User findUser(String token, String appKey) {
|
||||
return User.getFakeUser();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -28,4 +28,9 @@ public class HttpHeaderUserStrategy implements UserStrategy {
|
||||
public User findUser(HttpServletRequest request, HttpServletResponse response) {
|
||||
return userTokenUtils.getUser(request);
|
||||
}
|
||||
|
||||
@Override
|
||||
public User findUser(String token, String appKey) {
|
||||
return userTokenUtils.getUser(token, appKey);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,11 @@ public class UserTokenUtils {
|
||||
}
|
||||
|
||||
public String generateToken(UserWithPassword user, HttpServletRequest request) {
|
||||
String appKey = getAppKey(request);
|
||||
return generateToken(user, appKey);
|
||||
}
|
||||
|
||||
public String generateToken(UserWithPassword user, String appKey) {
|
||||
Map<String, Object> claims = new HashMap<>(5);
|
||||
claims.put(TOKEN_USER_ID, user.getId());
|
||||
claims.put(TOKEN_USER_NAME, StringUtils.isEmpty(user.getName()) ? "" : user.getName());
|
||||
@@ -44,7 +49,6 @@ public class UserTokenUtils {
|
||||
claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName());
|
||||
claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis());
|
||||
claims.put(TOKEN_IS_ADMIN, user.getIsAdmin());
|
||||
String appKey = getAppKey(request);
|
||||
return generate(claims, appKey);
|
||||
}
|
||||
|
||||
@@ -61,6 +65,15 @@ public class UserTokenUtils {
|
||||
public User getUser(HttpServletRequest request) {
|
||||
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
|
||||
final Claims claims = getClaims(token, request);
|
||||
return getUser(claims);
|
||||
}
|
||||
|
||||
public User getUser(String token, String appKey) {
|
||||
final Claims claims = getClaims(token, appKey);
|
||||
return getUser(claims);
|
||||
}
|
||||
|
||||
private User getUser(Claims claims) {
|
||||
Long userId = Long.parseLong(claims.getOrDefault(TOKEN_USER_ID, 0).toString());
|
||||
String userName = String.valueOf(claims.get(TOKEN_USER_NAME));
|
||||
String email = String.valueOf(claims.get(TOKEN_USER_EMAIL));
|
||||
@@ -92,6 +105,16 @@ public class UserTokenUtils {
|
||||
Claims claims;
|
||||
try {
|
||||
String appKey = getAppKey(request);
|
||||
claims = getClaims(token, appKey);
|
||||
} catch (Exception e) {
|
||||
throw new AccessException("parse user info from token failed :" + token);
|
||||
}
|
||||
return claims;
|
||||
}
|
||||
|
||||
private Claims getClaims(String token, String appKey) {
|
||||
Claims claims;
|
||||
try {
|
||||
String tokenSecret = getTokenSecret(appKey);
|
||||
claims = Jwts.parser()
|
||||
.setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8))
|
||||
|
||||
@@ -89,6 +89,12 @@
|
||||
<groupId>org.apache.calcite.avatica</groupId>
|
||||
<artifactId>avatica-core</artifactId>
|
||||
<version>${calcite.avatica.version}</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<artifactId>protobuf-java</artifactId>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.calcite</groupId>
|
||||
@@ -99,6 +105,10 @@
|
||||
<artifactId>log4j</artifactId>
|
||||
<groupId>log4j</groupId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<artifactId>protobuf-java</artifactId>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
|
||||
@@ -119,6 +119,22 @@
|
||||
<version>${postgresql.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java</artifactId>
|
||||
<version>${protobuf-java.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>flight-sql</artifactId>
|
||||
<version>${flight-sql.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>arrow-jdbc</artifactId>
|
||||
<version>${arrow-jdbc.version}</version>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
@@ -0,0 +1,93 @@
|
||||
package com.tencent.supersonic.headless.server.listener;
|
||||
|
||||
import com.tencent.supersonic.headless.server.service.FlightService;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.arrow.flight.FlightServer;
|
||||
import org.apache.arrow.flight.Location;
|
||||
import org.apache.arrow.memory.BufferAllocator;
|
||||
import org.apache.arrow.memory.RootAllocator;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* arrow flight jdbc server listener
|
||||
*/
|
||||
@Component
|
||||
@Slf4j
|
||||
public class FlightSqlListener implements CommandLineRunner {
|
||||
|
||||
@Value("${s2.flightSql.enable:false}")
|
||||
private Boolean enable = false;
|
||||
@Value("${s2.flightSql.host:localhost}")
|
||||
private String host = "localhost";
|
||||
@Value("${s2.flightSql.port:9081}")
|
||||
private Integer port = 9081;
|
||||
@Value("${s2.flightSql.executor:4}")
|
||||
private Integer executor = 4;
|
||||
@Value("${s2.flightSql.queue:128}")
|
||||
private Integer queue = 128;
|
||||
@Value("${s2.flightSql.expireMinute:10}")
|
||||
private Integer expireMinute = 10;
|
||||
|
||||
private final FlightService flightService;
|
||||
private ExecutorService executorService;
|
||||
private FlightServer flightServer;
|
||||
private BufferAllocator allocator;
|
||||
|
||||
public FlightSqlListener(FlightService flightService) {
|
||||
this.allocator = new RootAllocator();
|
||||
this.flightService = flightService;
|
||||
this.flightService.setLocation(host, port);
|
||||
executorService = Executors.newFixedThreadPool(executor);
|
||||
this.flightService.setExecutorService(executorService, queue, expireMinute);
|
||||
Location listenLocation = Location.forGrpcInsecure(host, port);
|
||||
flightServer = FlightServer.builder(allocator, listenLocation, this.flightService)
|
||||
.build();
|
||||
}
|
||||
|
||||
public String getHost() {
|
||||
return host;
|
||||
}
|
||||
|
||||
public Integer getPort() {
|
||||
return port;
|
||||
}
|
||||
|
||||
public void startServer() {
|
||||
try {
|
||||
log.info("Arrow Flight JDBC server started on {} {}", host, port);
|
||||
flightServer.start();
|
||||
} catch (Exception e) {
|
||||
log.error("FlightSqlListener start error {}", e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run(String... args) throws Exception {
|
||||
if (enable) {
|
||||
new Thread() {
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
startServer();
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
try {
|
||||
flightServer.close();
|
||||
allocator.close();
|
||||
} catch (Exception e) {
|
||||
log.error("flightServer close error {}", e);
|
||||
}
|
||||
}));
|
||||
//flightServer.awaitTermination();
|
||||
} catch (Exception e) {
|
||||
log.error("run error {}", e);
|
||||
}
|
||||
}
|
||||
}.start();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.tencent.supersonic.headless.server.service;
|
||||
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import org.apache.arrow.flight.sql.FlightSqlProducer;
|
||||
|
||||
public interface FlightService extends FlightSqlProducer {
|
||||
|
||||
void setLocation(String host, Integer port);
|
||||
|
||||
void setExecutorService(ExecutorService executorService, Integer queue, Integer expireMinute);
|
||||
}
|
||||
@@ -0,0 +1,337 @@
|
||||
package com.tencent.supersonic.headless.server.service.impl;
|
||||
|
||||
import static com.google.protobuf.Any.pack;
|
||||
import static com.google.protobuf.ByteString.copyFrom;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.UUID.randomUUID;
|
||||
import static org.apache.arrow.adapter.jdbc.JdbcToArrow.sqlToArrowVectorIterator;
|
||||
import static org.apache.arrow.adapter.jdbc.JdbcToArrowUtils.jdbcToArrowSchema;
|
||||
|
||||
import com.google.common.cache.Cache;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.Message;
|
||||
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.headless.api.pojo.Param;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.server.service.FlightService;
|
||||
import com.tencent.supersonic.headless.server.service.QueryService;
|
||||
import com.tencent.supersonic.headless.server.utils.FlightUtils;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.ResultSetMetaData;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Types;
|
||||
import java.util.Arrays;
|
||||
import java.util.Calendar;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import javax.sql.RowSetMetaData;
|
||||
import javax.sql.rowset.CachedRowSet;
|
||||
import javax.sql.rowset.RowSetFactory;
|
||||
import javax.sql.rowset.RowSetMetaDataImpl;
|
||||
import javax.sql.rowset.RowSetProvider;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.arrow.adapter.jdbc.ArrowVectorIterator;
|
||||
import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils;
|
||||
import org.apache.arrow.flight.CallHeaders;
|
||||
import org.apache.arrow.flight.CallStatus;
|
||||
import org.apache.arrow.flight.FlightConstants;
|
||||
import org.apache.arrow.flight.FlightDescriptor;
|
||||
import org.apache.arrow.flight.FlightEndpoint;
|
||||
import org.apache.arrow.flight.FlightInfo;
|
||||
import org.apache.arrow.flight.Location;
|
||||
import org.apache.arrow.flight.Result;
|
||||
import org.apache.arrow.flight.Ticket;
|
||||
import org.apache.arrow.flight.sql.BasicFlightSqlProducer;
|
||||
import org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest;
|
||||
import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest;
|
||||
import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult;
|
||||
import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery;
|
||||
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery;
|
||||
import org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery;
|
||||
import org.apache.arrow.memory.BufferAllocator;
|
||||
import org.apache.arrow.memory.RootAllocator;
|
||||
import org.apache.arrow.vector.VectorLoader;
|
||||
import org.apache.arrow.vector.VectorSchemaRoot;
|
||||
import org.apache.arrow.vector.VectorUnloader;
|
||||
import org.apache.arrow.vector.types.pojo.Schema;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* arrow flight FlightSqlProducer
|
||||
*/
|
||||
@Slf4j
|
||||
@Service("FlightService")
|
||||
public class FlightServiceImpl extends BasicFlightSqlProducer implements FlightService {
|
||||
|
||||
private String host;
|
||||
private Integer port;
|
||||
private ExecutorService executorService;
|
||||
private Cache<ByteString, SemanticQueryReq> preparedStatementCache;
|
||||
private final String dataSetIdHeaderKey = "dataSetId";
|
||||
private final String nameHeaderKey = "name";
|
||||
private final String passwordHeaderKey = "password";
|
||||
private final Calendar defaultCalendar = JdbcToArrowUtils.getUtcCalendar();
|
||||
private final QueryService queryService;
|
||||
private final AuthenticationConfig authenticationConfig;
|
||||
private final UserService userService;
|
||||
|
||||
public FlightServiceImpl(QueryService queryService,
|
||||
AuthenticationConfig authenticationConfig,
|
||||
UserService userService) {
|
||||
this.queryService = queryService;
|
||||
this.authenticationConfig = authenticationConfig;
|
||||
|
||||
this.userService = userService;
|
||||
}
|
||||
|
||||
public void setLocation(String host, Integer port) {
|
||||
this.host = host;
|
||||
this.port = port;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setExecutorService(ExecutorService executorService, Integer queue, Integer expireMinute) {
|
||||
this.executorService = executorService;
|
||||
this.preparedStatementCache =
|
||||
CacheBuilder.newBuilder()
|
||||
.maximumSize(queue)
|
||||
.expireAfterWrite(expireMinute, TimeUnit.MINUTES)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor) {
|
||||
return super.getFlightInfo(callContext, flightDescriptor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context,
|
||||
final ServerStreamListener listener) {
|
||||
final ByteString handle = ticketStatementQuery.getStatementHandle();
|
||||
log.info("getStreamStatement {} ", handle);
|
||||
executeQuery(handle, listener);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context,
|
||||
final FlightDescriptor descriptor) {
|
||||
try {
|
||||
ByteString preparedStatementHandle = addPrepared(context, request.getQuery());
|
||||
TicketStatementQuery ticket = TicketStatementQuery.newBuilder()
|
||||
.setStatementHandle(preparedStatementHandle)
|
||||
.build();
|
||||
return getFlightInfoForSchema(ticket, descriptor, null);
|
||||
} catch (Exception e) {
|
||||
log.error("getFlightInfoStatement error {}", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context,
|
||||
final ServerStreamListener listener) {
|
||||
log.info("getStreamPreparedStatement {}", command.getPreparedStatementHandle());
|
||||
executeQuery(command.getPreparedStatementHandle(), listener);
|
||||
}
|
||||
|
||||
private void executeQuery(ByteString hander, final ServerStreamListener listener) {
|
||||
SemanticQueryReq semanticQueryReq = preparedStatementCache.getIfPresent(hander);
|
||||
if (Objects.isNull(semanticQueryReq)) {
|
||||
listener.error(CallStatus.INTERNAL
|
||||
.withDescription("Failed to get prepared statement: empty")
|
||||
.toRuntimeException());
|
||||
log.error("getStreamPreparedStatement error {}", hander);
|
||||
listener.completed();
|
||||
return;
|
||||
}
|
||||
executorService.submit(() -> {
|
||||
BufferAllocator rootAllocator = new RootAllocator();
|
||||
try {
|
||||
Optional<Param> authOpt = semanticQueryReq.getParams().stream()
|
||||
.filter(p -> p.getName().equals(authenticationConfig.getTokenHttpHeaderKey())).findFirst();
|
||||
if (authOpt.isPresent()) {
|
||||
User user = UserHolder.findUser(authOpt.get().getValue(),
|
||||
authenticationConfig.getTokenHttpHeaderAppKey());
|
||||
SemanticQueryResp resp = queryService.queryByReq(semanticQueryReq, user);
|
||||
ResultSet resultSet = semanticQueryRespToResultSet(resp, semanticQueryReq.getDataSetId());
|
||||
final Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), defaultCalendar);
|
||||
try (final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
|
||||
final VectorLoader loader = new VectorLoader(vectorSchemaRoot);
|
||||
listener.start(vectorSchemaRoot);
|
||||
final ArrowVectorIterator iterator = sqlToArrowVectorIterator(resultSet, rootAllocator);
|
||||
while (iterator.hasNext()) {
|
||||
final VectorSchemaRoot batch = iterator.next();
|
||||
if (batch.getRowCount() == 0) {
|
||||
break;
|
||||
}
|
||||
final VectorUnloader unloader = new VectorUnloader(batch);
|
||||
loader.load(unloader.getRecordBatch());
|
||||
listener.putNext();
|
||||
vectorSchemaRoot.clear();
|
||||
}
|
||||
|
||||
listener.putNext();
|
||||
}
|
||||
|
||||
}
|
||||
} catch (Exception e) {
|
||||
listener.error(CallStatus.INTERNAL
|
||||
.withDescription(String.format("Failed to get exec statement %s", e.getMessage()))
|
||||
.toRuntimeException());
|
||||
log.error("getStreamPreparedStatement error {}", hander);
|
||||
} finally {
|
||||
preparedStatementCache.invalidate(hander);
|
||||
listener.completed();
|
||||
rootAllocator.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void closePreparedStatement(final ActionClosePreparedStatementRequest request, final CallContext context,
|
||||
final StreamListener<Result> listener) {
|
||||
log.info("closePreparedStatement {}", request.getPreparedStatementHandle());
|
||||
listener.onCompleted();
|
||||
}
|
||||
|
||||
@Override
|
||||
public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command,
|
||||
final CallContext context,
|
||||
final FlightDescriptor descriptor) {
|
||||
return getFlightInfoForSchema(command, descriptor, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createPreparedStatement(final ActionCreatePreparedStatementRequest request, final CallContext context,
|
||||
final StreamListener<Result> listener) {
|
||||
prepared(request, context, listener);
|
||||
}
|
||||
|
||||
private ByteString addPrepared(final CallContext context, String query) throws Exception {
|
||||
if (Arrays.asList(dataSetIdHeaderKey, nameHeaderKey, passwordHeaderKey).stream()
|
||||
.anyMatch(h -> !context.getMiddleware(FlightConstants.HEADER_KEY).headers().containsKey(h))) {
|
||||
throw new Exception(String.format("Failed to create prepared statement: HeaderCallOption miss %s %s %s",
|
||||
dataSetIdHeaderKey, nameHeaderKey, passwordHeaderKey));
|
||||
}
|
||||
Long dataSetId = Long.valueOf(
|
||||
context.getMiddleware(FlightConstants.HEADER_KEY).headers().get(dataSetIdHeaderKey));
|
||||
if (StringUtils.isBlank(query)) {
|
||||
throw new Exception("Failed to create prepared statement: query is empty");
|
||||
}
|
||||
try {
|
||||
String auth = getUserAuth(context.getMiddleware(FlightConstants.HEADER_KEY).headers());
|
||||
if (StringUtils.isBlank(auth)) {
|
||||
throw new Exception("auth empty");
|
||||
}
|
||||
final ByteString preparedStatementHandle = copyFrom(
|
||||
randomUUID().toString().getBytes(StandardCharsets.UTF_8));
|
||||
QuerySqlReq querySqlReq = new QuerySqlReq();
|
||||
querySqlReq.setDataSetId(dataSetId);
|
||||
querySqlReq.setSql(query);
|
||||
querySqlReq.setParams(Arrays.asList(new Param(authenticationConfig.getTokenHttpHeaderKey(), auth)));
|
||||
preparedStatementCache.put(preparedStatementHandle, querySqlReq);
|
||||
log.info("createPreparedStatement {} {} {} ", preparedStatementHandle, dataSetId, query);
|
||||
return preparedStatementHandle;
|
||||
} catch (Exception e) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private void prepared(final ActionCreatePreparedStatementRequest request, final CallContext context,
|
||||
final StreamListener<Result> listener) {
|
||||
try {
|
||||
ByteString preparedStatementHandle = addPrepared(context, request.getQuery());
|
||||
final ActionCreatePreparedStatementResult result = ActionCreatePreparedStatementResult.newBuilder()
|
||||
.setDatasetSchema(ByteString.EMPTY)
|
||||
.setParameterSchema(ByteString.empty())
|
||||
.setPreparedStatementHandle(preparedStatementHandle)
|
||||
.build();
|
||||
listener.onNext(new Result(pack(result).toByteArray()));
|
||||
} catch (Exception e) {
|
||||
listener.onError(CallStatus.INTERNAL
|
||||
.withDescription(String.format("Failed to create prepared statement: %s", e.getMessage()))
|
||||
.toRuntimeException());
|
||||
} finally {
|
||||
listener.onCompleted();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected <T extends Message> List<FlightEndpoint> determineEndpoints(T t, FlightDescriptor flightDescriptor,
|
||||
Schema schema) {
|
||||
throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException();
|
||||
}
|
||||
|
||||
private <T extends Message> FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor,
|
||||
final Schema schema) {
|
||||
final Ticket ticket = new Ticket(pack(request).toByteArray());
|
||||
Location listenLocation = Location.forGrpcInsecure(host, port);
|
||||
final List<FlightEndpoint> endpoints = singletonList(new FlightEndpoint(ticket, listenLocation));
|
||||
|
||||
return new FlightInfo(schema, descriptor, endpoints, -1, -1);
|
||||
}
|
||||
|
||||
private String getUserAuth(CallHeaders callHeaders) throws Exception {
|
||||
|
||||
UserReq userReq = new UserReq();
|
||||
userReq.setName(callHeaders.get(nameHeaderKey));
|
||||
userReq.setPassword(callHeaders.get(passwordHeaderKey));
|
||||
if (StringUtils.isBlank(userReq.getName()) || StringUtils.isBlank(userReq.getPassword())) {
|
||||
throw new Exception("name or password is empty");
|
||||
}
|
||||
String auth = userService.login(userReq, authenticationConfig.getTokenDefaultAppKey());
|
||||
return auth;
|
||||
}
|
||||
|
||||
private ResultSet semanticQueryRespToResultSet(SemanticQueryResp resp, Long dataSetId) throws SQLException {
|
||||
RowSetFactory factory = RowSetProvider.newFactory();
|
||||
CachedRowSet rowset = factory.createCachedRowSet();
|
||||
RowSetMetaData rowSetMetaData = new RowSetMetaDataImpl();
|
||||
int columnNum = resp.getColumns().size();
|
||||
rowSetMetaData.setColumnCount(columnNum);
|
||||
for (int i = 1; i <= columnNum; i++) {
|
||||
String columnName = resp.getColumns().get(i - 1).getNameEn();
|
||||
rowSetMetaData.setColumnName(i, columnName);
|
||||
Optional<Map<String, Object>> valOpt = resp.getResultList().stream()
|
||||
.filter(r -> r.containsKey(columnName) && Objects.nonNull(r.get(columnName))).findFirst();
|
||||
if (valOpt.isPresent()) {
|
||||
int type = FlightUtils.resolveType(valOpt.get());
|
||||
rowSetMetaData.setColumnType(i, type);
|
||||
rowSetMetaData.setNullable(i, FlightUtils.isNullable(type));
|
||||
} else {
|
||||
rowSetMetaData.setNullable(i, ResultSetMetaData.columnNullable);
|
||||
rowSetMetaData.setColumnType(i, Types.VARCHAR);
|
||||
}
|
||||
rowSetMetaData.setCatalogName(i, String.valueOf(dataSetId));
|
||||
rowSetMetaData.setSchemaName(i, dataSetIdHeaderKey);
|
||||
}
|
||||
rowset.setMetaData(rowSetMetaData);
|
||||
for (Map<String, Object> row : resp.getResultList()) {
|
||||
rowset.moveToInsertRow();
|
||||
for (int i = 1; i <= columnNum; i++) {
|
||||
String columnName = resp.getColumns().get(i - 1).getNameEn();
|
||||
if (row.containsKey(columnName)) {
|
||||
rowset.updateObject(i, row.get(columnName));
|
||||
} else {
|
||||
rowset.updateObject(i, null);
|
||||
}
|
||||
}
|
||||
rowset.insertRow();
|
||||
rowset.moveToCurrentRow();
|
||||
}
|
||||
return rowset;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package com.tencent.supersonic.headless.server.utils;
|
||||
|
||||
import java.sql.ResultSetMetaData;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Types;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
/**
|
||||
* tools for arrow flight sql
|
||||
*/
|
||||
public class FlightUtils {
|
||||
|
||||
public static int resolveType(Object value) {
|
||||
if (value instanceof Long) {
|
||||
return Types.BIGINT;
|
||||
}
|
||||
if (value instanceof Integer) {
|
||||
return Types.INTEGER;
|
||||
}
|
||||
if (value instanceof Double) {
|
||||
return Types.DOUBLE;
|
||||
}
|
||||
if (value instanceof String) {
|
||||
String val = String.valueOf(value);
|
||||
if (Pattern.matches("^\\d+$", val)) {
|
||||
return Types.BIGINT;
|
||||
} else if (Pattern.matches("^\\d+\\.\\d+$", val)) {
|
||||
return Types.DECIMAL;
|
||||
} else if (Pattern.matches("^\\d{4}-\\d{2}-\\d{2}$", val)) {
|
||||
return Types.DATE;
|
||||
} else if (Pattern.matches("^\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}$", val)) {
|
||||
return Types.TIME;
|
||||
}
|
||||
}
|
||||
return Types.VARCHAR;
|
||||
}
|
||||
|
||||
public static int isNullable(int sqlType) throws SQLException {
|
||||
switch (sqlType) {
|
||||
case Types.VARCHAR:
|
||||
case Types.DECIMAL:
|
||||
return ResultSetMetaData.columnNullable;
|
||||
default:
|
||||
return ResultSetMetaData.columnNullableUnknown;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package com.tencent.supersonic.headless;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.auth.authentication.strategy.FakeUserStrategy;
|
||||
import com.tencent.supersonic.headless.server.listener.FlightSqlListener;
|
||||
import org.apache.arrow.flight.CallHeaders;
|
||||
import org.apache.arrow.flight.FlightCallHeaders;
|
||||
import org.apache.arrow.flight.FlightClient;
|
||||
import org.apache.arrow.flight.FlightInfo;
|
||||
import org.apache.arrow.flight.FlightStream;
|
||||
import org.apache.arrow.flight.HeaderCallOption;
|
||||
import org.apache.arrow.flight.Location;
|
||||
import org.apache.arrow.flight.sql.FlightSqlClient;
|
||||
import org.apache.arrow.memory.RootAllocator;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
|
||||
public class FlightSqlTest extends BaseTest {
|
||||
|
||||
|
||||
@Autowired
|
||||
private FlightSqlListener flightSqlListener;
|
||||
@Autowired
|
||||
private FakeUserStrategy fakeUserStrategy;
|
||||
|
||||
@Test
|
||||
void test01() throws Exception {
|
||||
String host = flightSqlListener.getHost();
|
||||
Integer port = flightSqlListener.getPort();
|
||||
UserHolder.setStrategy(fakeUserStrategy);
|
||||
flightSqlListener.startServer();
|
||||
FlightSqlClient sqlClient = new FlightSqlClient(
|
||||
FlightClient.builder(new RootAllocator(Integer.MAX_VALUE), Location.forGrpcInsecure(host, port))
|
||||
.build());
|
||||
|
||||
CallHeaders headers = new FlightCallHeaders();
|
||||
headers.insert("dataSetId", "1");
|
||||
headers.insert("name", "admin");
|
||||
headers.insert("password", "admin");
|
||||
HeaderCallOption headerOption = new HeaderCallOption(headers);
|
||||
try (final FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(
|
||||
"SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门",
|
||||
headerOption)) {
|
||||
final FlightInfo info = preparedStatement.execute();
|
||||
FlightStream stream = sqlClient.getStream(info
|
||||
.getEndpoints()
|
||||
.get(0).getTicket());
|
||||
int rowCnt = 0;
|
||||
int colCnt = 0;
|
||||
while (stream.next()) {
|
||||
if (stream.getRoot().getRowCount() > 0) {
|
||||
colCnt = stream.getRoot().getFieldVectors().size();
|
||||
rowCnt += stream.getRoot().getRowCount();
|
||||
}
|
||||
}
|
||||
assertEquals(2, colCnt);
|
||||
assertTrue(rowCnt > 0);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void test02() throws Exception {
|
||||
String host = flightSqlListener.getHost();
|
||||
Integer port = flightSqlListener.getPort();
|
||||
UserHolder.setStrategy(fakeUserStrategy);
|
||||
flightSqlListener.startServer();
|
||||
FlightSqlClient sqlClient = new FlightSqlClient(
|
||||
FlightClient.builder(new RootAllocator(Integer.MAX_VALUE), Location.forGrpcInsecure(host, port))
|
||||
.build());
|
||||
|
||||
CallHeaders headers = new FlightCallHeaders();
|
||||
headers.insert("dataSetId", "1");
|
||||
headers.insert("name", "admin");
|
||||
headers.insert("password", "admin");
|
||||
HeaderCallOption headerOption = new HeaderCallOption(headers);
|
||||
try {
|
||||
FlightInfo flightInfo = sqlClient.execute("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门",
|
||||
headerOption);
|
||||
FlightStream stream = sqlClient.getStream(flightInfo
|
||||
.getEndpoints()
|
||||
.get(0).getTicket());
|
||||
int rowCnt = 0;
|
||||
int colCnt = 0;
|
||||
while (stream.next()) {
|
||||
if (stream.getRoot().getRowCount() > 0) {
|
||||
colCnt = stream.getRoot().getFieldVectors().size();
|
||||
rowCnt += stream.getRoot().getRowCount();
|
||||
}
|
||||
}
|
||||
assertEquals(2, colCnt);
|
||||
assertTrue(rowCnt > 0);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
-- sample user
|
||||
MERGE INTO s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1);
|
||||
MERGE INTO s2_user (id, `name`, password, salt, display_name, email, is_admin) values (1, 'admin','c3VwZXJzb25pY0BiaWNvbTD12g9wGXESwL7+o7xUW90=','jGl25bVBBBW96Qi9Te4V3w==','admin','admin@xx.com', 1);
|
||||
MERGE INTO s2_user (id, `name`, password, display_name, email) values (2, 'jack','123456','jack','jack@xx.com');
|
||||
MERGE INTO s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com');
|
||||
MERGE INTO s2_user (id, `name`, password, display_name, email, is_admin) values (4, 'lucy','123456','lucy','lucy@xx.com', 1);
|
||||
|
||||
@@ -91,6 +91,7 @@ create table IF NOT EXISTS s2_user
|
||||
password varchar(100) null,
|
||||
email varchar(100) null,
|
||||
is_admin INT null,
|
||||
salt varchar(100) null,
|
||||
PRIMARY KEY (`id`)
|
||||
);
|
||||
COMMENT ON TABLE s2_user IS 'user information table';
|
||||
|
||||
4
pom.xml
4
pom.xml
@@ -77,6 +77,10 @@
|
||||
<postgresql.version>42.7.1</postgresql.version>
|
||||
<st.version>4.0.8</st.version>
|
||||
<duckdb_jdbc.version>0.10.0</duckdb_jdbc.version>
|
||||
<protobuf-java.version>3.23.1</protobuf-java.version>
|
||||
<flight-sql.version>15.0.2</flight-sql.version>
|
||||
<arrow-jdbc.version>15.0.2</arrow-jdbc.version>
|
||||
<flight-sql-jdbc-driver.version>15.0.2</flight-sql-jdbc-driver.version>
|
||||
</properties>
|
||||
|
||||
<dependencyManagement>
|
||||
|
||||
Reference in New Issue
Block a user