mirror of
https://github.com/actions/runner.git
synced 2026-04-09 04:03:17 +08:00
Compare commits
2 Commits
main
...
rentziass/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac65885854 | ||
|
|
21ba579c06 |
@@ -66,6 +66,7 @@ namespace GitHub.Runner.Worker.Dap
|
||||
|
||||
// Dev Tunnel relay host for remote debugging
|
||||
private TunnelRelayTunnelHost _tunnelRelayHost;
|
||||
private WebSocketDapBridge _webSocketBridge;
|
||||
|
||||
// Cancellation source for the connection loop, cancelled in StopAsync
|
||||
// so AcceptTcpClientAsync unblocks cleanly without relying on listener disposal.
|
||||
@@ -74,6 +75,10 @@ namespace GitHub.Runner.Worker.Dap
|
||||
// When true, skip tunnel relay startup (unit tests only)
|
||||
internal bool SkipTunnelRelay { get; set; }
|
||||
|
||||
// When true, skip the public websocket bridge and expose the raw DAP
|
||||
// listener directly on the configured tunnel port (unit tests only).
|
||||
internal bool SkipWebSocketBridge { get; set; }
|
||||
|
||||
// Synchronization for step execution
|
||||
private TaskCompletionSource<DapCommand> _commandTcs;
|
||||
private readonly object _stateLock = new object();
|
||||
@@ -108,6 +113,7 @@ namespace GitHub.Runner.Worker.Dap
|
||||
_state == DapSessionState.Running;
|
||||
|
||||
internal DapSessionState State => _state;
|
||||
internal int InternalDapPort => (_listener?.LocalEndpoint as IPEndPoint)?.Port ?? 0;
|
||||
|
||||
public override void Initialize(IHostContext hostContext)
|
||||
{
|
||||
@@ -133,9 +139,22 @@ namespace GitHub.Runner.Worker.Dap
|
||||
_jobContext = jobContext;
|
||||
_readyTcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
|
||||
_listener = new TcpListener(IPAddress.Loopback, debuggerConfig.Tunnel.Port);
|
||||
var dapPort = SkipWebSocketBridge ? debuggerConfig.Tunnel.Port : 0;
|
||||
_listener = new TcpListener(IPAddress.Loopback, dapPort);
|
||||
_listener.Start();
|
||||
Trace.Info($"DAP debugger listening on {_listener.LocalEndpoint}");
|
||||
if (SkipWebSocketBridge)
|
||||
{
|
||||
Trace.Info($"DAP debugger listening on {_listener.LocalEndpoint}");
|
||||
}
|
||||
else
|
||||
{
|
||||
Trace.Info($"Internal DAP debugger listening on {_listener.LocalEndpoint}");
|
||||
_webSocketBridge = new WebSocketDapBridge(
|
||||
HostContext.GetTrace("DapWebSocketBridge"),
|
||||
debuggerConfig.Tunnel.Port,
|
||||
InternalDapPort);
|
||||
_webSocketBridge.Start();
|
||||
}
|
||||
|
||||
// Start Dev Tunnel relay so remote clients reach the local DAP port.
|
||||
// The relay is torn down explicitly in StopAsync (after the DAP session
|
||||
@@ -274,6 +293,22 @@ namespace GitHub.Runner.Worker.Dap
|
||||
_tunnelRelayHost = null;
|
||||
}
|
||||
|
||||
if (_webSocketBridge != null)
|
||||
{
|
||||
Trace.Info("Stopping WebSocket DAP bridge");
|
||||
var disposeTask = _webSocketBridge.DisposeAsync().AsTask();
|
||||
if (await Task.WhenAny(disposeTask, Task.Delay(5_000)) != disposeTask)
|
||||
{
|
||||
Trace.Warning("WebSocket DAP bridge dispose timed out after 5s");
|
||||
}
|
||||
else
|
||||
{
|
||||
Trace.Info("WebSocket DAP bridge stopped");
|
||||
}
|
||||
|
||||
_webSocketBridge = null;
|
||||
}
|
||||
|
||||
CleanupConnection();
|
||||
|
||||
// Cancel the connection loop first so AcceptTcpClientAsync unblocks
|
||||
@@ -315,6 +350,7 @@ namespace GitHub.Runner.Worker.Dap
|
||||
_connectionLoopTask = null;
|
||||
_loopCts?.Dispose();
|
||||
_loopCts = null;
|
||||
_webSocketBridge = null;
|
||||
}
|
||||
|
||||
public async Task OnStepStartingAsync(IStep step)
|
||||
|
||||
812
src/Runner.Worker/Dap/WebSocketDapBridge.cs
Normal file
812
src/Runner.Worker/Dap/WebSocketDapBridge.cs
Normal file
@@ -0,0 +1,812 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using System.Net.WebSockets;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using GitHub.Runner.Common;
|
||||
|
||||
namespace GitHub.Runner.Worker.Dap
|
||||
{
|
||||
internal sealed class WebSocketDapBridge : IAsyncDisposable
|
||||
{
|
||||
internal enum IncomingStreamPrefixKind
|
||||
{
|
||||
Unknown,
|
||||
HttpWebSocketUpgrade,
|
||||
PreUpgradedWebSocket,
|
||||
WebSocketReservedBits,
|
||||
Http2Preface,
|
||||
TlsClientHello,
|
||||
}
|
||||
|
||||
private const int _bufferSize = 32 * 1024;
|
||||
private const int _maxHeaderLineLength = 8 * 1024;
|
||||
private const int _defaultMaxInboundMessageSize = 10 * 1024 * 1024; // 10 MB
|
||||
private static readonly TimeSpan _keepAliveInterval = TimeSpan.FromSeconds(30);
|
||||
private static readonly TimeSpan _closeTimeout = TimeSpan.FromSeconds(5);
|
||||
private const string _webSocketAcceptMagic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
||||
|
||||
private readonly Tracing _trace;
|
||||
private readonly int _listenPort;
|
||||
private readonly int _targetPort;
|
||||
|
||||
private TcpListener _listener;
|
||||
private CancellationTokenSource _loopCts;
|
||||
private Task _acceptLoopTask;
|
||||
|
||||
// Overridable for unit tests to avoid allocating 10 MB payloads.
|
||||
internal int MaxInboundMessageSize { get; set; } = _defaultMaxInboundMessageSize;
|
||||
|
||||
public WebSocketDapBridge(Tracing trace, int listenPort, int targetPort)
|
||||
{
|
||||
_trace = trace ?? throw new ArgumentNullException(nameof(trace));
|
||||
_listenPort = listenPort;
|
||||
_targetPort = targetPort;
|
||||
}
|
||||
|
||||
public void Start()
|
||||
{
|
||||
if (_listener != null)
|
||||
{
|
||||
throw new InvalidOperationException("WebSocket DAP bridge already started.");
|
||||
}
|
||||
|
||||
_listener = new TcpListener(IPAddress.Loopback, _listenPort);
|
||||
_listener.Start();
|
||||
_loopCts = new CancellationTokenSource();
|
||||
_acceptLoopTask = AcceptLoopAsync(_loopCts.Token);
|
||||
|
||||
_trace.Info($"WebSocket DAP bridge listening on {_listener.LocalEndpoint} -> 127.0.0.1:{_targetPort}");
|
||||
}
|
||||
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
try
|
||||
{
|
||||
_loopCts?.Cancel();
|
||||
}
|
||||
catch
|
||||
{
|
||||
// best effort during shutdown
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
_listener?.Stop();
|
||||
}
|
||||
catch
|
||||
{
|
||||
// best effort during shutdown
|
||||
}
|
||||
|
||||
if (_acceptLoopTask != null)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _acceptLoopTask;
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// expected on shutdown
|
||||
}
|
||||
}
|
||||
|
||||
_loopCts?.Dispose();
|
||||
_loopCts = null;
|
||||
_listener = null;
|
||||
_acceptLoopTask = null;
|
||||
}
|
||||
|
||||
private async Task AcceptLoopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
TcpClient client = null;
|
||||
try
|
||||
{
|
||||
client = await _listener.AcceptTcpClientAsync(cancellationToken);
|
||||
client.NoDelay = true;
|
||||
await HandleClientAsync(client, cancellationToken);
|
||||
}
|
||||
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
break;
|
||||
}
|
||||
catch (ObjectDisposedException) when (cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
break;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
client?.Dispose();
|
||||
_trace.Warning($"WebSocket DAP bridge connection error ({ex.GetType().Name})");
|
||||
_trace.Error(ex);
|
||||
}
|
||||
}
|
||||
|
||||
_trace.Info("WebSocket DAP bridge accept loop ended");
|
||||
}
|
||||
|
||||
private async Task HandleClientAsync(TcpClient incomingClient, CancellationToken cancellationToken)
|
||||
{
|
||||
using (incomingClient)
|
||||
using (var incomingStream = incomingClient.GetStream())
|
||||
{
|
||||
_trace.Info($"WebSocket DAP bridge accepted client {incomingClient.Client.RemoteEndPoint}");
|
||||
|
||||
var webSocket = await AcceptWebSocketAsync(incomingStream, cancellationToken);
|
||||
if (webSocket == null)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
using (webSocket)
|
||||
using (var dapClient = new TcpClient())
|
||||
{
|
||||
dapClient.NoDelay = true;
|
||||
await dapClient.ConnectAsync(IPAddress.Loopback, _targetPort, cancellationToken);
|
||||
|
||||
using (var dapStream = dapClient.GetStream())
|
||||
using (var sessionCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
|
||||
{
|
||||
var proxyToken = sessionCts.Token;
|
||||
var wsToTcpTask = PumpWebSocketToTcpAsync(webSocket, dapStream, proxyToken);
|
||||
var tcpToWsTask = PumpTcpToWebSocketAsync(dapStream, webSocket, proxyToken);
|
||||
|
||||
var completedTask = await Task.WhenAny(wsToTcpTask, tcpToWsTask);
|
||||
sessionCts.Cancel();
|
||||
|
||||
try
|
||||
{
|
||||
await completedTask;
|
||||
}
|
||||
catch (OperationCanceledException) when (proxyToken.IsCancellationRequested)
|
||||
{
|
||||
// expected during shutdown
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
await Task.WhenAll(wsToTcpTask, tcpToWsTask);
|
||||
}
|
||||
catch (OperationCanceledException) when (proxyToken.IsCancellationRequested)
|
||||
{
|
||||
// expected during shutdown
|
||||
}
|
||||
catch (IOException)
|
||||
{
|
||||
// peer disconnected while unwinding
|
||||
}
|
||||
catch (WebSocketException)
|
||||
{
|
||||
// peer disconnected while unwinding
|
||||
}
|
||||
}
|
||||
|
||||
await CloseWebSocketAsync(webSocket);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<WebSocket> AcceptWebSocketAsync(NetworkStream stream, CancellationToken cancellationToken)
|
||||
{
|
||||
var initialBytes = await ReadInitialBytesAsync(stream, cancellationToken);
|
||||
if (initialBytes == null || initialBytes.Length == 0)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var prefixKind = ClassifyIncomingStreamPrefix(initialBytes);
|
||||
if (prefixKind == IncomingStreamPrefixKind.PreUpgradedWebSocket)
|
||||
{
|
||||
_trace.Info($"Treating incoming tunnel stream as an already-upgraded websocket connection ({DescribeInitialBytes(initialBytes)})");
|
||||
return WebSocket.CreateFromStream(
|
||||
new ReplayableStream(stream, initialBytes),
|
||||
isServer: true,
|
||||
subProtocol: null,
|
||||
keepAliveInterval: _keepAliveInterval);
|
||||
}
|
||||
|
||||
if (prefixKind != IncomingStreamPrefixKind.HttpWebSocketUpgrade)
|
||||
{
|
||||
_trace.Warning($"Unsupported debugger tunnel stream prefix ({prefixKind}): {DescribeInitialBytes(initialBytes)}");
|
||||
return null;
|
||||
}
|
||||
|
||||
var handshakeStream = new ReplayableStream(stream, initialBytes);
|
||||
var requestLine = await ReadLineAsync(handshakeStream, cancellationToken);
|
||||
if (string.IsNullOrEmpty(requestLine))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
|
||||
while (true)
|
||||
{
|
||||
var line = await ReadLineAsync(handshakeStream, cancellationToken);
|
||||
if (line == null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (line.Length == 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
var separatorIndex = line.IndexOf(':');
|
||||
if (separatorIndex <= 0)
|
||||
{
|
||||
await WriteHttpErrorAsync(stream, HttpStatusCode.BadRequest, "Invalid HTTP header.", cancellationToken);
|
||||
return null;
|
||||
}
|
||||
|
||||
var headerName = line.Substring(0, separatorIndex).Trim();
|
||||
var headerValue = line.Substring(separatorIndex + 1).Trim();
|
||||
|
||||
if (headers.TryGetValue(headerName, out var existingValue))
|
||||
{
|
||||
headers[headerName] = $"{existingValue}, {headerValue}";
|
||||
}
|
||||
else
|
||||
{
|
||||
headers[headerName] = headerValue;
|
||||
}
|
||||
}
|
||||
|
||||
if (!IsValidWebSocketRequest(requestLine, headers))
|
||||
{
|
||||
_trace.Info($"Rejected non-websocket request: {requestLine}");
|
||||
await WriteHttpErrorAsync(stream, HttpStatusCode.BadRequest, "Expected a websocket upgrade request.", cancellationToken);
|
||||
return null;
|
||||
}
|
||||
|
||||
var webSocketKey = headers["Sec-WebSocket-Key"];
|
||||
var acceptValue = ComputeAcceptValue(webSocketKey);
|
||||
var responseBytes = Encoding.ASCII.GetBytes(
|
||||
"HTTP/1.1 101 Switching Protocols\r\n" +
|
||||
"Connection: Upgrade\r\n" +
|
||||
"Upgrade: websocket\r\n" +
|
||||
$"Sec-WebSocket-Accept: {acceptValue}\r\n" +
|
||||
"\r\n");
|
||||
|
||||
await handshakeStream.WriteAsync(responseBytes, 0, responseBytes.Length, cancellationToken);
|
||||
await handshakeStream.FlushAsync(cancellationToken);
|
||||
|
||||
_trace.Info("WebSocket DAP bridge completed websocket handshake");
|
||||
return WebSocket.CreateFromStream(handshakeStream, isServer: true, subProtocol: null, keepAliveInterval: _keepAliveInterval);
|
||||
}
|
||||
|
||||
private async Task PumpWebSocketToTcpAsync(WebSocket source, NetworkStream destination, CancellationToken cancellationToken)
|
||||
{
|
||||
var buffer = new byte[_bufferSize];
|
||||
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
using (var messageStream = new MemoryStream())
|
||||
{
|
||||
WebSocketReceiveResult result;
|
||||
do
|
||||
{
|
||||
result = await source.ReceiveAsync(new ArraySegment<byte>(buffer), cancellationToken);
|
||||
if (result.MessageType == WebSocketMessageType.Close)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (result.MessageType != WebSocketMessageType.Binary &&
|
||||
result.MessageType != WebSocketMessageType.Text)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if (result.Count > 0)
|
||||
{
|
||||
if (messageStream.Length + result.Count > MaxInboundMessageSize)
|
||||
{
|
||||
_trace.Warning($"WebSocket message exceeds maximum allowed size of {MaxInboundMessageSize} bytes, closing connection");
|
||||
await source.CloseAsync(
|
||||
WebSocketCloseStatus.MessageTooBig,
|
||||
$"Message exceeds {MaxInboundMessageSize} byte limit",
|
||||
CancellationToken.None);
|
||||
return;
|
||||
}
|
||||
|
||||
messageStream.Write(buffer, 0, result.Count);
|
||||
}
|
||||
}
|
||||
while (!result.EndOfMessage);
|
||||
|
||||
if (result.MessageType != WebSocketMessageType.Binary &&
|
||||
result.MessageType != WebSocketMessageType.Text)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var messageBytes = messageStream.ToArray();
|
||||
if (messageBytes.Length == 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var contentLengthHeader = Encoding.ASCII.GetBytes($"Content-Length: {messageBytes.Length}\r\n\r\n");
|
||||
await destination.WriteAsync(contentLengthHeader, 0, contentLengthHeader.Length, cancellationToken);
|
||||
await destination.WriteAsync(messageBytes, 0, messageBytes.Length, cancellationToken);
|
||||
await destination.FlushAsync(cancellationToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task PumpTcpToWebSocketAsync(NetworkStream source, WebSocket destination, CancellationToken cancellationToken)
|
||||
{
|
||||
var readBuffer = new byte[_bufferSize];
|
||||
var dapBuffer = new List<byte>();
|
||||
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
var bytesRead = await source.ReadAsync(readBuffer, 0, readBuffer.Length, cancellationToken);
|
||||
if (bytesRead == 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
for (int i = 0; i < bytesRead; i++)
|
||||
{
|
||||
dapBuffer.Add(readBuffer[i]);
|
||||
}
|
||||
|
||||
while (TryParseDapMessage(dapBuffer, out var messageBody))
|
||||
{
|
||||
await destination.SendAsync(
|
||||
new ArraySegment<byte>(messageBody),
|
||||
WebSocketMessageType.Text,
|
||||
endOfMessage: true,
|
||||
cancellationToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static bool TryParseDapMessage(List<byte> buffer, out byte[] messageBody)
|
||||
{
|
||||
messageBody = null;
|
||||
|
||||
var headerEndMarker = new byte[] { (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' };
|
||||
var headerEndIndex = FindSequence(buffer, headerEndMarker);
|
||||
if (headerEndIndex == -1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var headerBytes = buffer.GetRange(0, headerEndIndex).ToArray();
|
||||
var headerText = Encoding.ASCII.GetString(headerBytes);
|
||||
|
||||
var contentLength = -1;
|
||||
foreach (var line in headerText.Split(new[] { "\r\n" }, StringSplitOptions.RemoveEmptyEntries))
|
||||
{
|
||||
if (line.StartsWith("Content-Length:", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
var valueStart = line.IndexOf(':') + 1;
|
||||
if (int.TryParse(line.Substring(valueStart).Trim(), out var parsedLength))
|
||||
{
|
||||
contentLength = parsedLength;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (contentLength < 0)
|
||||
{
|
||||
buffer.RemoveRange(0, headerEndIndex + 4);
|
||||
return false;
|
||||
}
|
||||
|
||||
var messageStart = headerEndIndex + 4;
|
||||
var messageEnd = messageStart + contentLength;
|
||||
|
||||
if (buffer.Count < messageEnd)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
messageBody = buffer.GetRange(messageStart, contentLength).ToArray();
|
||||
buffer.RemoveRange(0, messageEnd);
|
||||
return true;
|
||||
}
|
||||
|
||||
private static int FindSequence(List<byte> buffer, byte[] sequence)
|
||||
{
|
||||
if (buffer.Count < sequence.Length)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int i = 0; i <= buffer.Count - sequence.Length; i++)
|
||||
{
|
||||
var match = true;
|
||||
for (int j = 0; j < sequence.Length; j++)
|
||||
{
|
||||
if (buffer[i + j] != sequence[j])
|
||||
{
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (match)
|
||||
{
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
private static bool IsValidWebSocketRequest(string requestLine, IDictionary<string, string> headers)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(requestLine))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var requestLineParts = requestLine.Split(' ');
|
||||
if (requestLineParts.Length < 3 || !string.Equals(requestLineParts[0], "GET", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return HeaderContainsToken(headers, "Connection", "Upgrade") &&
|
||||
HeaderContainsToken(headers, "Upgrade", "websocket") &&
|
||||
headers.ContainsKey("Sec-WebSocket-Key");
|
||||
}
|
||||
|
||||
private static bool HeaderContainsToken(IDictionary<string, string> headers, string headerName, string expectedToken)
|
||||
{
|
||||
if (!headers.TryGetValue(headerName, out var headerValue) || string.IsNullOrWhiteSpace(headerValue))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return headerValue
|
||||
.Split(',')
|
||||
.Select(token => token.Trim())
|
||||
.Any(token => string.Equals(token, expectedToken, StringComparison.OrdinalIgnoreCase));
|
||||
}
|
||||
|
||||
private static string ComputeAcceptValue(string webSocketKey)
|
||||
{
|
||||
using (var sha1 = SHA1.Create())
|
||||
{
|
||||
var inputBytes = Encoding.ASCII.GetBytes($"{webSocketKey}{_webSocketAcceptMagic}");
|
||||
var hashBytes = sha1.ComputeHash(inputBytes);
|
||||
return Convert.ToBase64String(hashBytes);
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task<string> ReadLineAsync(Stream stream, CancellationToken cancellationToken)
|
||||
{
|
||||
var lineBuilder = new StringBuilder();
|
||||
var buffer = new byte[1];
|
||||
var previousWasCarriageReturn = false;
|
||||
|
||||
while (true)
|
||||
{
|
||||
var bytesRead = await stream.ReadAsync(buffer, 0, 1, cancellationToken);
|
||||
if (bytesRead == 0)
|
||||
{
|
||||
return lineBuilder.Length > 0 ? lineBuilder.ToString() : null;
|
||||
}
|
||||
|
||||
var currentChar = (char)buffer[0];
|
||||
if (currentChar == '\n' && previousWasCarriageReturn)
|
||||
{
|
||||
if (lineBuilder.Length > 0 && lineBuilder[lineBuilder.Length - 1] == '\r')
|
||||
{
|
||||
lineBuilder.Length--;
|
||||
}
|
||||
|
||||
return lineBuilder.ToString();
|
||||
}
|
||||
|
||||
previousWasCarriageReturn = currentChar == '\r';
|
||||
lineBuilder.Append(currentChar);
|
||||
|
||||
if (lineBuilder.Length > _maxHeaderLineLength)
|
||||
{
|
||||
throw new InvalidDataException($"HTTP header line exceeds maximum length of {_maxHeaderLineLength}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task<byte[]> ReadInitialBytesAsync(NetworkStream stream, CancellationToken cancellationToken)
|
||||
{
|
||||
var buffer = new byte[4];
|
||||
var totalRead = 0;
|
||||
|
||||
while (totalRead < buffer.Length)
|
||||
{
|
||||
var bytesRead = await stream.ReadAsync(buffer, totalRead, buffer.Length - totalRead, cancellationToken);
|
||||
if (bytesRead == 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
totalRead += bytesRead;
|
||||
}
|
||||
|
||||
if (totalRead == 0)
|
||||
{
|
||||
return Array.Empty<byte>();
|
||||
}
|
||||
|
||||
if (totalRead == buffer.Length)
|
||||
{
|
||||
return buffer;
|
||||
}
|
||||
|
||||
var initialBytes = new byte[totalRead];
|
||||
Array.Copy(buffer, initialBytes, totalRead);
|
||||
return initialBytes;
|
||||
}
|
||||
|
||||
internal static IncomingStreamPrefixKind ClassifyIncomingStreamPrefix(byte[] initialBytes)
|
||||
{
|
||||
if (LooksLikeHttpUpgrade(initialBytes))
|
||||
{
|
||||
return IncomingStreamPrefixKind.HttpWebSocketUpgrade;
|
||||
}
|
||||
|
||||
if (LooksLikeHttp2Preface(initialBytes))
|
||||
{
|
||||
return IncomingStreamPrefixKind.Http2Preface;
|
||||
}
|
||||
|
||||
if (LooksLikeTlsClientHello(initialBytes))
|
||||
{
|
||||
return IncomingStreamPrefixKind.TlsClientHello;
|
||||
}
|
||||
|
||||
if (LooksLikeWebSocketFramePrefix(initialBytes, requireReservedBitsClear: false))
|
||||
{
|
||||
return HasReservedBitsSet(initialBytes[0])
|
||||
? IncomingStreamPrefixKind.WebSocketReservedBits
|
||||
: IncomingStreamPrefixKind.PreUpgradedWebSocket;
|
||||
}
|
||||
|
||||
return IncomingStreamPrefixKind.Unknown;
|
||||
}
|
||||
|
||||
internal static string DescribeInitialBytes(byte[] initialBytes)
|
||||
{
|
||||
if (initialBytes == null || initialBytes.Length == 0)
|
||||
{
|
||||
return "no bytes read";
|
||||
}
|
||||
|
||||
var hex = BitConverter.ToString(initialBytes);
|
||||
var ascii = new string(initialBytes.Select(value => value >= 32 && value <= 126 ? (char)value : '.').ToArray());
|
||||
return $"hex={hex}, ascii=\"{ascii}\"";
|
||||
}
|
||||
|
||||
private static bool LooksLikeHttpUpgrade(byte[] initialBytes)
|
||||
{
|
||||
if (initialBytes == null || initialBytes.Length < 4)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return initialBytes[0] == (byte)'G' &&
|
||||
initialBytes[1] == (byte)'E' &&
|
||||
initialBytes[2] == (byte)'T' &&
|
||||
initialBytes[3] == (byte)' ';
|
||||
}
|
||||
|
||||
private static bool LooksLikeHttp2Preface(byte[] initialBytes)
|
||||
{
|
||||
if (initialBytes == null || initialBytes.Length < 4)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return initialBytes[0] == (byte)'P' &&
|
||||
initialBytes[1] == (byte)'R' &&
|
||||
initialBytes[2] == (byte)'I' &&
|
||||
initialBytes[3] == (byte)' ';
|
||||
}
|
||||
|
||||
private static bool LooksLikeTlsClientHello(byte[] initialBytes)
|
||||
{
|
||||
if (initialBytes == null || initialBytes.Length < 3)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return initialBytes[0] == 0x16 &&
|
||||
initialBytes[1] == 0x03 &&
|
||||
initialBytes[2] >= 0x00 &&
|
||||
initialBytes[2] <= 0x04;
|
||||
}
|
||||
|
||||
private static bool LooksLikeWebSocketFramePrefix(byte[] initialBytes, bool requireReservedBitsClear)
|
||||
{
|
||||
if (initialBytes == null || initialBytes.Length < 2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var firstByte = initialBytes[0];
|
||||
var secondByte = initialBytes[1];
|
||||
var opcode = firstByte & 0x0F;
|
||||
var isMasked = (secondByte & 0x80) != 0;
|
||||
|
||||
if (!isMasked || !IsSupportedWebSocketOpcode(opcode))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return !requireReservedBitsClear || !HasReservedBitsSet(firstByte);
|
||||
}
|
||||
|
||||
private static bool HasReservedBitsSet(byte firstByte)
|
||||
{
|
||||
return (firstByte & 0x70) != 0;
|
||||
}
|
||||
|
||||
private static bool IsSupportedWebSocketOpcode(int opcode)
|
||||
{
|
||||
switch (opcode)
|
||||
{
|
||||
case 0x0:
|
||||
case 0x1:
|
||||
case 0x2:
|
||||
case 0x8:
|
||||
case 0x9:
|
||||
case 0xA:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task WriteHttpErrorAsync(
|
||||
NetworkStream stream,
|
||||
HttpStatusCode statusCode,
|
||||
string message,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var bodyBytes = Encoding.UTF8.GetBytes(message);
|
||||
var responseBytes = Encoding.ASCII.GetBytes(
|
||||
$"HTTP/1.1 {(int)statusCode} {statusCode}\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"Content-Type: text/plain; charset=utf-8\r\n" +
|
||||
$"Content-Length: {bodyBytes.Length}\r\n" +
|
||||
"Sec-WebSocket-Version: 13\r\n" +
|
||||
"\r\n");
|
||||
|
||||
await stream.WriteAsync(responseBytes, 0, responseBytes.Length, cancellationToken);
|
||||
await stream.WriteAsync(bodyBytes, 0, bodyBytes.Length, cancellationToken);
|
||||
await stream.FlushAsync(cancellationToken);
|
||||
}
|
||||
|
||||
private static async Task CloseWebSocketAsync(WebSocket webSocket)
|
||||
{
|
||||
if (webSocket == null)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (webSocket.State != WebSocketState.Open &&
|
||||
webSocket.State != WebSocketState.CloseReceived)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
using var cts = new CancellationTokenSource(_closeTimeout);
|
||||
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cts.Token);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Graceful close timed out, abort the connection.
|
||||
webSocket.Abort();
|
||||
}
|
||||
catch (WebSocketException)
|
||||
{
|
||||
// Peer already disconnected.
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class ReplayableStream : Stream
|
||||
{
|
||||
private readonly Stream _innerStream;
|
||||
private readonly byte[] _prefixBytes;
|
||||
private int _prefixOffset;
|
||||
|
||||
public ReplayableStream(Stream innerStream, byte[] prefixBytes)
|
||||
{
|
||||
_innerStream = innerStream ?? throw new ArgumentNullException(nameof(innerStream));
|
||||
_prefixBytes = prefixBytes ?? Array.Empty<byte>();
|
||||
}
|
||||
|
||||
public override bool CanRead => _innerStream.CanRead;
|
||||
public override bool CanSeek => false;
|
||||
public override bool CanWrite => _innerStream.CanWrite;
|
||||
public override long Length => throw new NotSupportedException();
|
||||
|
||||
public override long Position
|
||||
{
|
||||
get => throw new NotSupportedException();
|
||||
set => throw new NotSupportedException();
|
||||
}
|
||||
|
||||
public override void Flush() => _innerStream.Flush();
|
||||
|
||||
public override Task FlushAsync(CancellationToken cancellationToken) => _innerStream.FlushAsync(cancellationToken);
|
||||
|
||||
public override int Read(byte[] buffer, int offset, int count)
|
||||
{
|
||||
if (TryReadPrefix(buffer, offset, count, out var bytesRead))
|
||||
{
|
||||
return bytesRead;
|
||||
}
|
||||
|
||||
return _innerStream.Read(buffer, offset, count);
|
||||
}
|
||||
|
||||
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
||||
{
|
||||
if (TryReadPrefix(buffer, offset, count, out var bytesRead))
|
||||
{
|
||||
return bytesRead;
|
||||
}
|
||||
|
||||
return await _innerStream.ReadAsync(buffer, offset, count, cancellationToken);
|
||||
}
|
||||
|
||||
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (_prefixOffset < _prefixBytes.Length)
|
||||
{
|
||||
var bytesToCopy = Math.Min(buffer.Length, _prefixBytes.Length - _prefixOffset);
|
||||
new ReadOnlySpan<byte>(_prefixBytes, _prefixOffset, bytesToCopy).CopyTo(buffer.Span);
|
||||
_prefixOffset += bytesToCopy;
|
||||
return bytesToCopy;
|
||||
}
|
||||
|
||||
return await _innerStream.ReadAsync(buffer, cancellationToken);
|
||||
}
|
||||
|
||||
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
|
||||
|
||||
public override void SetLength(long value) => throw new NotSupportedException();
|
||||
|
||||
public override void Write(byte[] buffer, int offset, int count) => _innerStream.Write(buffer, offset, count);
|
||||
|
||||
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
|
||||
_innerStream.WriteAsync(buffer, offset, count, cancellationToken);
|
||||
|
||||
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) =>
|
||||
_innerStream.WriteAsync(buffer, cancellationToken);
|
||||
|
||||
private bool TryReadPrefix(byte[] buffer, int offset, int count, out int bytesRead)
|
||||
{
|
||||
if (_prefixOffset >= _prefixBytes.Length)
|
||||
{
|
||||
bytesRead = 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
bytesRead = Math.Min(count, _prefixBytes.Length - _prefixOffset);
|
||||
Array.Copy(_prefixBytes, _prefixOffset, buffer, offset, bytesRead);
|
||||
_prefixOffset += bytesRead;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
using System;
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using System.Net.WebSockets;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
@@ -20,12 +21,13 @@ namespace GitHub.Runner.Common.Tests.Worker
|
||||
private const string TunnelConnectTimeoutVariable = "ACTIONS_RUNNER_DAP_TUNNEL_CONNECT_TIMEOUT_SECONDS";
|
||||
private DapDebugger _debugger;
|
||||
|
||||
private TestHostContext CreateTestContext([CallerMemberName] string testName = "")
|
||||
private TestHostContext CreateTestContext(bool enableWebSocketBridge = false, [CallerMemberName] string testName = "")
|
||||
{
|
||||
var hc = new TestHostContext(this, testName);
|
||||
_debugger = new DapDebugger();
|
||||
_debugger.Initialize(hc);
|
||||
_debugger.SkipTunnelRelay = true;
|
||||
_debugger.SkipWebSocketBridge = !enableWebSocketBridge;
|
||||
return hc;
|
||||
}
|
||||
|
||||
@@ -71,6 +73,13 @@ namespace GitHub.Runner.Common.Tests.Worker
|
||||
return client;
|
||||
}
|
||||
|
||||
private static async Task<ClientWebSocket> ConnectWebSocketClientAsync(int port)
|
||||
{
|
||||
var client = new ClientWebSocket();
|
||||
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
|
||||
return client;
|
||||
}
|
||||
|
||||
private static async Task SendRequestAsync(NetworkStream stream, Request request)
|
||||
{
|
||||
var json = JsonConvert.SerializeObject(request);
|
||||
@@ -83,6 +92,14 @@ namespace GitHub.Runner.Common.Tests.Worker
|
||||
await stream.FlushAsync();
|
||||
}
|
||||
|
||||
private static async Task SendRequestAsync(WebSocket client, Request request)
|
||||
{
|
||||
var json = JsonConvert.SerializeObject(request);
|
||||
var body = Encoding.UTF8.GetBytes(json);
|
||||
|
||||
await client.SendAsync(new ArraySegment<byte>(body), WebSocketMessageType.Text, endOfMessage: true, CancellationToken.None);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Reads a single DAP-framed message from a stream with a timeout.
|
||||
/// Parses the Content-Length header, reads exactly that many bytes,
|
||||
@@ -141,6 +158,52 @@ namespace GitHub.Runner.Common.Tests.Worker
|
||||
return Encoding.UTF8.GetString(body);
|
||||
}
|
||||
|
||||
private static async Task<string> ReadWebSocketDataUntilAsync(WebSocket client, TimeSpan timeout, params string[] expectedFragments)
|
||||
{
|
||||
using var cts = new CancellationTokenSource(timeout);
|
||||
var buffer = new byte[4096];
|
||||
var allMessages = new StringBuilder();
|
||||
|
||||
while (true)
|
||||
{
|
||||
using var messageStream = new MemoryStream();
|
||||
WebSocketReceiveResult result;
|
||||
do
|
||||
{
|
||||
result = await client.ReceiveAsync(new ArraySegment<byte>(buffer), cts.Token);
|
||||
if (result.MessageType == WebSocketMessageType.Close)
|
||||
{
|
||||
throw new EndOfStreamException("WebSocket closed before expected DAP messages were received.");
|
||||
}
|
||||
|
||||
if (result.Count > 0)
|
||||
{
|
||||
messageStream.Write(buffer, 0, result.Count);
|
||||
}
|
||||
}
|
||||
while (!result.EndOfMessage);
|
||||
|
||||
var messageText = Encoding.UTF8.GetString(messageStream.ToArray());
|
||||
allMessages.Append(messageText);
|
||||
|
||||
var text = allMessages.ToString();
|
||||
var containsAllFragments = true;
|
||||
foreach (var fragment in expectedFragments)
|
||||
{
|
||||
if (!text.Contains(fragment, StringComparison.Ordinal))
|
||||
{
|
||||
containsAllFragments = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (containsAllFragments)
|
||||
{
|
||||
return text;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static Mock<IExecutionContext> CreateJobContextWithTunnel(CancellationToken cancellationToken, ushort port, string jobName = null)
|
||||
{
|
||||
var tunnel = new GitHub.DistributedTask.Pipelines.DebuggerTunnelInfo
|
||||
@@ -208,6 +271,82 @@ namespace GitHub.Runner.Common.Tests.Worker
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Level", "L0")]
|
||||
[Trait("Category", "Worker")]
|
||||
public async Task StartAsyncWithWebSocketBridgeAcceptsInitializeOverWebSocket()
|
||||
{
|
||||
using (CreateTestContext(enableWebSocketBridge: true))
|
||||
{
|
||||
var port = GetFreePort();
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
|
||||
var jobContext = CreateJobContextWithTunnel(cts.Token, port);
|
||||
await _debugger.StartAsync(jobContext.Object);
|
||||
|
||||
Assert.NotEqual(0, _debugger.InternalDapPort);
|
||||
Assert.NotEqual(port, _debugger.InternalDapPort);
|
||||
|
||||
using var client = await ConnectWebSocketClientAsync(port);
|
||||
await SendRequestAsync(client, new Request
|
||||
{
|
||||
Seq = 1,
|
||||
Type = "request",
|
||||
Command = "initialize"
|
||||
});
|
||||
|
||||
var response = await ReadWebSocketDataUntilAsync(
|
||||
client,
|
||||
TimeSpan.FromSeconds(5),
|
||||
"\"type\":\"response\"",
|
||||
"\"command\":\"initialize\"",
|
||||
"\"event\":\"initialized\"");
|
||||
|
||||
Assert.Contains("\"success\":true", response);
|
||||
await _debugger.StopAsync();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Level", "L0")]
|
||||
[Trait("Category", "Worker")]
|
||||
public async Task StartAsyncWithWebSocketBridgeAcceptsPreUpgradedWebSocketStream()
|
||||
{
|
||||
using (CreateTestContext(enableWebSocketBridge: true))
|
||||
{
|
||||
var port = GetFreePort();
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
|
||||
var jobContext = CreateJobContextWithTunnel(cts.Token, port);
|
||||
await _debugger.StartAsync(jobContext.Object);
|
||||
|
||||
Assert.NotEqual(0, _debugger.InternalDapPort);
|
||||
Assert.NotEqual(port, _debugger.InternalDapPort);
|
||||
|
||||
using var tcpClient = await ConnectClientAsync(port);
|
||||
using var webSocket = WebSocket.CreateFromStream(
|
||||
tcpClient.GetStream(),
|
||||
isServer: false,
|
||||
subProtocol: null,
|
||||
keepAliveInterval: TimeSpan.FromSeconds(30));
|
||||
|
||||
await SendRequestAsync(webSocket, new Request
|
||||
{
|
||||
Seq = 1,
|
||||
Type = "request",
|
||||
Command = "initialize"
|
||||
});
|
||||
|
||||
var response = await ReadWebSocketDataUntilAsync(
|
||||
webSocket,
|
||||
TimeSpan.FromSeconds(5),
|
||||
"\"type\":\"response\"",
|
||||
"\"command\":\"initialize\"",
|
||||
"\"event\":\"initialized\"");
|
||||
|
||||
Assert.Contains("\"success\":true", response);
|
||||
await _debugger.StopAsync();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Level", "L0")]
|
||||
[Trait("Category", "Worker")]
|
||||
|
||||
245
src/Test/L0/Worker/WebSocketDapBridgeL0.cs
Normal file
245
src/Test/L0/Worker/WebSocketDapBridgeL0.cs
Normal file
@@ -0,0 +1,245 @@
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using System.Net.WebSockets;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using GitHub.Runner.Common;
|
||||
using GitHub.Runner.Worker.Dap;
|
||||
using Xunit;
|
||||
|
||||
namespace GitHub.Runner.Common.Tests.Worker
|
||||
{
|
||||
public sealed class WebSocketDapBridgeL0
|
||||
{
|
||||
private TestHostContext CreateTestContext([CallerMemberName] string testName = "")
|
||||
{
|
||||
return new TestHostContext(this, testName);
|
||||
}
|
||||
|
||||
private static ushort GetFreePort()
|
||||
{
|
||||
using var listener = new TcpListener(IPAddress.Loopback, 0);
|
||||
listener.Start();
|
||||
return (ushort)((IPEndPoint)listener.LocalEndpoint).Port;
|
||||
}
|
||||
|
||||
private static async Task<byte[]> ReadWebSocketMessageAsync(ClientWebSocket client, TimeSpan timeout)
|
||||
{
|
||||
using var cts = new CancellationTokenSource(timeout);
|
||||
using var buffer = new MemoryStream();
|
||||
var receiveBuffer = new byte[1024];
|
||||
|
||||
while (true)
|
||||
{
|
||||
var result = await client.ReceiveAsync(new ArraySegment<byte>(receiveBuffer), cts.Token);
|
||||
if (result.MessageType == WebSocketMessageType.Close)
|
||||
{
|
||||
throw new EndOfStreamException("WebSocket closed unexpectedly.");
|
||||
}
|
||||
|
||||
if (result.Count > 0)
|
||||
{
|
||||
buffer.Write(receiveBuffer, 0, result.Count);
|
||||
}
|
||||
|
||||
if (result.EndOfMessage)
|
||||
{
|
||||
return buffer.ToArray();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Level", "L0")]
|
||||
[Trait("Category", "Worker")]
|
||||
public async Task BridgeForwardsWebSocketFramesToTcpAndBack()
|
||||
{
|
||||
using var hc = CreateTestContext();
|
||||
using var targetListener = new TcpListener(IPAddress.Loopback, 0);
|
||||
targetListener.Start();
|
||||
|
||||
var targetPort = ((IPEndPoint)targetListener.LocalEndpoint).Port;
|
||||
var bridgePort = GetFreePort();
|
||||
|
||||
await using var bridge = new WebSocketDapBridge(hc.GetTrace("DapWebSocketBridge"), bridgePort, targetPort);
|
||||
bridge.Start();
|
||||
|
||||
var echoTask = Task.Run(async () =>
|
||||
{
|
||||
using var targetClient = await targetListener.AcceptTcpClientAsync();
|
||||
using var stream = targetClient.GetStream();
|
||||
|
||||
var headerBuilder = new StringBuilder();
|
||||
var buffer = new byte[1];
|
||||
var contentLength = -1;
|
||||
|
||||
while (true)
|
||||
{
|
||||
var bytesRead = await stream.ReadAsync(buffer, 0, 1);
|
||||
if (bytesRead == 0) break;
|
||||
|
||||
headerBuilder.Append((char)buffer[0]);
|
||||
var headers = headerBuilder.ToString();
|
||||
if (headers.EndsWith("\r\n\r\n", StringComparison.Ordinal))
|
||||
{
|
||||
foreach (var line in headers.Split(new[] { "\r\n" }, StringSplitOptions.RemoveEmptyEntries))
|
||||
{
|
||||
if (line.StartsWith("Content-Length: ", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
contentLength = int.Parse(line.Substring("Content-Length: ".Length).Trim());
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
var body = new byte[contentLength];
|
||||
var totalRead = 0;
|
||||
while (totalRead < contentLength)
|
||||
{
|
||||
var bytesRead = await stream.ReadAsync(body, totalRead, contentLength - totalRead);
|
||||
if (bytesRead == 0) break;
|
||||
totalRead += bytesRead;
|
||||
}
|
||||
|
||||
var header = $"Content-Length: {body.Length}\r\n\r\n";
|
||||
var headerBytes = Encoding.ASCII.GetBytes(header);
|
||||
await stream.WriteAsync(headerBytes, 0, headerBytes.Length);
|
||||
await stream.WriteAsync(body, 0, body.Length);
|
||||
await stream.FlushAsync();
|
||||
});
|
||||
|
||||
using var client = new ClientWebSocket();
|
||||
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{bridgePort}/"), CancellationToken.None);
|
||||
|
||||
var dapMessage = "{\"type\":\"request\",\"seq\":1,\"command\":\"initialize\"}";
|
||||
var payload = Encoding.UTF8.GetBytes(dapMessage);
|
||||
await client.SendAsync(new ArraySegment<byte>(payload), WebSocketMessageType.Text, endOfMessage: true, CancellationToken.None);
|
||||
|
||||
var echoed = await ReadWebSocketMessageAsync(client, TimeSpan.FromSeconds(5));
|
||||
Assert.Equal(payload, echoed);
|
||||
|
||||
await echoTask;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Level", "L0")]
|
||||
[Trait("Category", "Worker")]
|
||||
public async Task BridgeRejectsNonWebSocketRequests()
|
||||
{
|
||||
using var hc = CreateTestContext();
|
||||
var bridgePort = GetFreePort();
|
||||
|
||||
await using var bridge = new WebSocketDapBridge(hc.GetTrace("DapWebSocketBridge"), bridgePort, GetFreePort());
|
||||
bridge.Start();
|
||||
|
||||
using var client = new TcpClient();
|
||||
await client.ConnectAsync(IPAddress.Loopback, bridgePort);
|
||||
using var stream = client.GetStream();
|
||||
|
||||
var request = Encoding.ASCII.GetBytes(
|
||||
"GET / HTTP/1.1\r\n" +
|
||||
"Host: localhost\r\n" +
|
||||
"\r\n");
|
||||
await stream.WriteAsync(request, 0, request.Length);
|
||||
await stream.FlushAsync();
|
||||
|
||||
// Read until the server closes the connection (Connection: close).
|
||||
// A single ReadAsync may return a partial response on some platforms.
|
||||
using var ms = new MemoryStream();
|
||||
var responseBuffer = new byte[1024];
|
||||
int bytesRead;
|
||||
while ((bytesRead = await stream.ReadAsync(responseBuffer, 0, responseBuffer.Length)) > 0)
|
||||
{
|
||||
ms.Write(responseBuffer, 0, bytesRead);
|
||||
}
|
||||
|
||||
var response = Encoding.ASCII.GetString(ms.ToArray());
|
||||
|
||||
Assert.Contains("400 BadRequest", response);
|
||||
Assert.Contains("Expected a websocket upgrade request.", response);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[Trait("Level", "L0")]
|
||||
[Trait("Category", "Worker")]
|
||||
[InlineData(new byte[] { (byte)'G', (byte)'E', (byte)'T', (byte)' ' }, 1)]
|
||||
[InlineData(new byte[] { 0x81, 0x85, 0x00, 0x00 }, 2)]
|
||||
[InlineData(new byte[] { 0xC1, 0x85, 0x00, 0x00 }, 3)]
|
||||
[InlineData(new byte[] { (byte)'P', (byte)'R', (byte)'I', (byte)' ' }, 4)]
|
||||
[InlineData(new byte[] { 0x16, 0x03, 0x03, 0x01 }, 5)]
|
||||
[InlineData(new byte[] { (byte)'B', (byte)'A', (byte)'D', (byte)'!' }, 0)]
|
||||
public void ClassifyIncomingStreamPrefixDetectsExpectedProtocols(byte[] initialBytes, int expectedKind)
|
||||
{
|
||||
var actualKind = WebSocketDapBridge.ClassifyIncomingStreamPrefix(initialBytes);
|
||||
Assert.Equal((WebSocketDapBridge.IncomingStreamPrefixKind)expectedKind, actualKind);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Level", "L0")]
|
||||
[Trait("Category", "Worker")]
|
||||
public async Task BridgeRejectsOversizedWebSocketMessage()
|
||||
{
|
||||
using var hc = CreateTestContext();
|
||||
using var targetListener = new TcpListener(IPAddress.Loopback, 0);
|
||||
targetListener.Start();
|
||||
|
||||
var targetPort = ((IPEndPoint)targetListener.LocalEndpoint).Port;
|
||||
var bridgePort = GetFreePort();
|
||||
|
||||
await using var bridge = new WebSocketDapBridge(hc.GetTrace("DapWebSocketBridge"), bridgePort, targetPort);
|
||||
bridge.MaxInboundMessageSize = 64; // artificially small limit for testing
|
||||
bridge.Start();
|
||||
|
||||
using var client = new ClientWebSocket();
|
||||
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{bridgePort}/"), CancellationToken.None);
|
||||
|
||||
// Send a message that exceeds the 64-byte limit
|
||||
var oversizedPayload = new byte[128];
|
||||
Array.Fill(oversizedPayload, (byte)'X');
|
||||
await client.SendAsync(
|
||||
new ArraySegment<byte>(oversizedPayload),
|
||||
WebSocketMessageType.Text,
|
||||
endOfMessage: true,
|
||||
CancellationToken.None);
|
||||
|
||||
// The bridge should close the connection with MessageTooBig
|
||||
var receiveBuffer = new byte[256];
|
||||
var result = await client.ReceiveAsync(
|
||||
new ArraySegment<byte>(receiveBuffer),
|
||||
new CancellationTokenSource(TimeSpan.FromSeconds(5)).Token);
|
||||
|
||||
Assert.Equal(WebSocketMessageType.Close, result.MessageType);
|
||||
Assert.Equal(WebSocketCloseStatus.MessageTooBig, client.CloseStatus);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Level", "L0")]
|
||||
[Trait("Category", "Worker")]
|
||||
public async Task BridgeDisposeCompletesWhenPeerDoesNotCloseGracefully()
|
||||
{
|
||||
using var hc = CreateTestContext();
|
||||
using var targetListener = new TcpListener(IPAddress.Loopback, 0);
|
||||
targetListener.Start();
|
||||
|
||||
var targetPort = ((IPEndPoint)targetListener.LocalEndpoint).Port;
|
||||
var bridgePort = GetFreePort();
|
||||
|
||||
var bridge = new WebSocketDapBridge(hc.GetTrace("DapWebSocketBridge"), bridgePort, targetPort);
|
||||
bridge.Start();
|
||||
|
||||
// Connect a raw TCP client but never perform WebSocket close handshake
|
||||
using var rawClient = new TcpClient();
|
||||
await rawClient.ConnectAsync(IPAddress.Loopback, bridgePort);
|
||||
|
||||
// Dispose should complete within a bounded time, not hang
|
||||
var disposeTask = bridge.DisposeAsync().AsTask();
|
||||
var completed = await Task.WhenAny(disposeTask, Task.Delay(TimeSpan.FromSeconds(15)));
|
||||
Assert.True(completed == disposeTask, "Bridge dispose should complete within the timeout, not hang on a non-cooperative peer");
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user