Compare commits

..

2 Commits

Author SHA1 Message Date
Francesco Renzi
ac65885854 fix windows 2026-04-08 09:41:31 -07:00
Francesco Renzi
21ba579c06 Add WS bridge over DAP TCP server 2026-04-08 08:53:41 -07:00
4 changed files with 1236 additions and 4 deletions

View File

@@ -66,6 +66,7 @@ namespace GitHub.Runner.Worker.Dap
// Dev Tunnel relay host for remote debugging
private TunnelRelayTunnelHost _tunnelRelayHost;
private WebSocketDapBridge _webSocketBridge;
// Cancellation source for the connection loop, cancelled in StopAsync
// so AcceptTcpClientAsync unblocks cleanly without relying on listener disposal.
@@ -74,6 +75,10 @@ namespace GitHub.Runner.Worker.Dap
// When true, skip tunnel relay startup (unit tests only)
internal bool SkipTunnelRelay { get; set; }
// When true, skip the public websocket bridge and expose the raw DAP
// listener directly on the configured tunnel port (unit tests only).
internal bool SkipWebSocketBridge { get; set; }
// Synchronization for step execution
private TaskCompletionSource<DapCommand> _commandTcs;
private readonly object _stateLock = new object();
@@ -108,6 +113,7 @@ namespace GitHub.Runner.Worker.Dap
_state == DapSessionState.Running;
internal DapSessionState State => _state;
internal int InternalDapPort => (_listener?.LocalEndpoint as IPEndPoint)?.Port ?? 0;
public override void Initialize(IHostContext hostContext)
{
@@ -133,9 +139,22 @@ namespace GitHub.Runner.Worker.Dap
_jobContext = jobContext;
_readyTcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
_listener = new TcpListener(IPAddress.Loopback, debuggerConfig.Tunnel.Port);
var dapPort = SkipWebSocketBridge ? debuggerConfig.Tunnel.Port : 0;
_listener = new TcpListener(IPAddress.Loopback, dapPort);
_listener.Start();
Trace.Info($"DAP debugger listening on {_listener.LocalEndpoint}");
if (SkipWebSocketBridge)
{
Trace.Info($"DAP debugger listening on {_listener.LocalEndpoint}");
}
else
{
Trace.Info($"Internal DAP debugger listening on {_listener.LocalEndpoint}");
_webSocketBridge = new WebSocketDapBridge(
HostContext.GetTrace("DapWebSocketBridge"),
debuggerConfig.Tunnel.Port,
InternalDapPort);
_webSocketBridge.Start();
}
// Start Dev Tunnel relay so remote clients reach the local DAP port.
// The relay is torn down explicitly in StopAsync (after the DAP session
@@ -274,6 +293,22 @@ namespace GitHub.Runner.Worker.Dap
_tunnelRelayHost = null;
}
if (_webSocketBridge != null)
{
Trace.Info("Stopping WebSocket DAP bridge");
var disposeTask = _webSocketBridge.DisposeAsync().AsTask();
if (await Task.WhenAny(disposeTask, Task.Delay(5_000)) != disposeTask)
{
Trace.Warning("WebSocket DAP bridge dispose timed out after 5s");
}
else
{
Trace.Info("WebSocket DAP bridge stopped");
}
_webSocketBridge = null;
}
CleanupConnection();
// Cancel the connection loop first so AcceptTcpClientAsync unblocks
@@ -315,6 +350,7 @@ namespace GitHub.Runner.Worker.Dap
_connectionLoopTask = null;
_loopCts?.Dispose();
_loopCts = null;
_webSocketBridge = null;
}
public async Task OnStepStartingAsync(IStep step)

View File

@@ -0,0 +1,812 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Net.WebSockets;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using GitHub.Runner.Common;
namespace GitHub.Runner.Worker.Dap
{
internal sealed class WebSocketDapBridge : IAsyncDisposable
{
internal enum IncomingStreamPrefixKind
{
Unknown,
HttpWebSocketUpgrade,
PreUpgradedWebSocket,
WebSocketReservedBits,
Http2Preface,
TlsClientHello,
}
private const int _bufferSize = 32 * 1024;
private const int _maxHeaderLineLength = 8 * 1024;
private const int _defaultMaxInboundMessageSize = 10 * 1024 * 1024; // 10 MB
private static readonly TimeSpan _keepAliveInterval = TimeSpan.FromSeconds(30);
private static readonly TimeSpan _closeTimeout = TimeSpan.FromSeconds(5);
private const string _webSocketAcceptMagic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
private readonly Tracing _trace;
private readonly int _listenPort;
private readonly int _targetPort;
private TcpListener _listener;
private CancellationTokenSource _loopCts;
private Task _acceptLoopTask;
// Overridable for unit tests to avoid allocating 10 MB payloads.
internal int MaxInboundMessageSize { get; set; } = _defaultMaxInboundMessageSize;
public WebSocketDapBridge(Tracing trace, int listenPort, int targetPort)
{
_trace = trace ?? throw new ArgumentNullException(nameof(trace));
_listenPort = listenPort;
_targetPort = targetPort;
}
public void Start()
{
if (_listener != null)
{
throw new InvalidOperationException("WebSocket DAP bridge already started.");
}
_listener = new TcpListener(IPAddress.Loopback, _listenPort);
_listener.Start();
_loopCts = new CancellationTokenSource();
_acceptLoopTask = AcceptLoopAsync(_loopCts.Token);
_trace.Info($"WebSocket DAP bridge listening on {_listener.LocalEndpoint} -> 127.0.0.1:{_targetPort}");
}
public async ValueTask DisposeAsync()
{
try
{
_loopCts?.Cancel();
}
catch
{
// best effort during shutdown
}
try
{
_listener?.Stop();
}
catch
{
// best effort during shutdown
}
if (_acceptLoopTask != null)
{
try
{
await _acceptLoopTask;
}
catch (OperationCanceledException)
{
// expected on shutdown
}
}
_loopCts?.Dispose();
_loopCts = null;
_listener = null;
_acceptLoopTask = null;
}
private async Task AcceptLoopAsync(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
TcpClient client = null;
try
{
client = await _listener.AcceptTcpClientAsync(cancellationToken);
client.NoDelay = true;
await HandleClientAsync(client, cancellationToken);
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
break;
}
catch (ObjectDisposedException) when (cancellationToken.IsCancellationRequested)
{
break;
}
catch (Exception ex)
{
client?.Dispose();
_trace.Warning($"WebSocket DAP bridge connection error ({ex.GetType().Name})");
_trace.Error(ex);
}
}
_trace.Info("WebSocket DAP bridge accept loop ended");
}
private async Task HandleClientAsync(TcpClient incomingClient, CancellationToken cancellationToken)
{
using (incomingClient)
using (var incomingStream = incomingClient.GetStream())
{
_trace.Info($"WebSocket DAP bridge accepted client {incomingClient.Client.RemoteEndPoint}");
var webSocket = await AcceptWebSocketAsync(incomingStream, cancellationToken);
if (webSocket == null)
{
return;
}
using (webSocket)
using (var dapClient = new TcpClient())
{
dapClient.NoDelay = true;
await dapClient.ConnectAsync(IPAddress.Loopback, _targetPort, cancellationToken);
using (var dapStream = dapClient.GetStream())
using (var sessionCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
{
var proxyToken = sessionCts.Token;
var wsToTcpTask = PumpWebSocketToTcpAsync(webSocket, dapStream, proxyToken);
var tcpToWsTask = PumpTcpToWebSocketAsync(dapStream, webSocket, proxyToken);
var completedTask = await Task.WhenAny(wsToTcpTask, tcpToWsTask);
sessionCts.Cancel();
try
{
await completedTask;
}
catch (OperationCanceledException) when (proxyToken.IsCancellationRequested)
{
// expected during shutdown
}
try
{
await Task.WhenAll(wsToTcpTask, tcpToWsTask);
}
catch (OperationCanceledException) when (proxyToken.IsCancellationRequested)
{
// expected during shutdown
}
catch (IOException)
{
// peer disconnected while unwinding
}
catch (WebSocketException)
{
// peer disconnected while unwinding
}
}
await CloseWebSocketAsync(webSocket);
}
}
}
private async Task<WebSocket> AcceptWebSocketAsync(NetworkStream stream, CancellationToken cancellationToken)
{
var initialBytes = await ReadInitialBytesAsync(stream, cancellationToken);
if (initialBytes == null || initialBytes.Length == 0)
{
return null;
}
var prefixKind = ClassifyIncomingStreamPrefix(initialBytes);
if (prefixKind == IncomingStreamPrefixKind.PreUpgradedWebSocket)
{
_trace.Info($"Treating incoming tunnel stream as an already-upgraded websocket connection ({DescribeInitialBytes(initialBytes)})");
return WebSocket.CreateFromStream(
new ReplayableStream(stream, initialBytes),
isServer: true,
subProtocol: null,
keepAliveInterval: _keepAliveInterval);
}
if (prefixKind != IncomingStreamPrefixKind.HttpWebSocketUpgrade)
{
_trace.Warning($"Unsupported debugger tunnel stream prefix ({prefixKind}): {DescribeInitialBytes(initialBytes)}");
return null;
}
var handshakeStream = new ReplayableStream(stream, initialBytes);
var requestLine = await ReadLineAsync(handshakeStream, cancellationToken);
if (string.IsNullOrEmpty(requestLine))
{
return null;
}
var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
while (true)
{
var line = await ReadLineAsync(handshakeStream, cancellationToken);
if (line == null)
{
return null;
}
if (line.Length == 0)
{
break;
}
var separatorIndex = line.IndexOf(':');
if (separatorIndex <= 0)
{
await WriteHttpErrorAsync(stream, HttpStatusCode.BadRequest, "Invalid HTTP header.", cancellationToken);
return null;
}
var headerName = line.Substring(0, separatorIndex).Trim();
var headerValue = line.Substring(separatorIndex + 1).Trim();
if (headers.TryGetValue(headerName, out var existingValue))
{
headers[headerName] = $"{existingValue}, {headerValue}";
}
else
{
headers[headerName] = headerValue;
}
}
if (!IsValidWebSocketRequest(requestLine, headers))
{
_trace.Info($"Rejected non-websocket request: {requestLine}");
await WriteHttpErrorAsync(stream, HttpStatusCode.BadRequest, "Expected a websocket upgrade request.", cancellationToken);
return null;
}
var webSocketKey = headers["Sec-WebSocket-Key"];
var acceptValue = ComputeAcceptValue(webSocketKey);
var responseBytes = Encoding.ASCII.GetBytes(
"HTTP/1.1 101 Switching Protocols\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: websocket\r\n" +
$"Sec-WebSocket-Accept: {acceptValue}\r\n" +
"\r\n");
await handshakeStream.WriteAsync(responseBytes, 0, responseBytes.Length, cancellationToken);
await handshakeStream.FlushAsync(cancellationToken);
_trace.Info("WebSocket DAP bridge completed websocket handshake");
return WebSocket.CreateFromStream(handshakeStream, isServer: true, subProtocol: null, keepAliveInterval: _keepAliveInterval);
}
private async Task PumpWebSocketToTcpAsync(WebSocket source, NetworkStream destination, CancellationToken cancellationToken)
{
var buffer = new byte[_bufferSize];
while (!cancellationToken.IsCancellationRequested)
{
using (var messageStream = new MemoryStream())
{
WebSocketReceiveResult result;
do
{
result = await source.ReceiveAsync(new ArraySegment<byte>(buffer), cancellationToken);
if (result.MessageType == WebSocketMessageType.Close)
{
return;
}
if (result.MessageType != WebSocketMessageType.Binary &&
result.MessageType != WebSocketMessageType.Text)
{
break;
}
if (result.Count > 0)
{
if (messageStream.Length + result.Count > MaxInboundMessageSize)
{
_trace.Warning($"WebSocket message exceeds maximum allowed size of {MaxInboundMessageSize} bytes, closing connection");
await source.CloseAsync(
WebSocketCloseStatus.MessageTooBig,
$"Message exceeds {MaxInboundMessageSize} byte limit",
CancellationToken.None);
return;
}
messageStream.Write(buffer, 0, result.Count);
}
}
while (!result.EndOfMessage);
if (result.MessageType != WebSocketMessageType.Binary &&
result.MessageType != WebSocketMessageType.Text)
{
continue;
}
var messageBytes = messageStream.ToArray();
if (messageBytes.Length == 0)
{
continue;
}
var contentLengthHeader = Encoding.ASCII.GetBytes($"Content-Length: {messageBytes.Length}\r\n\r\n");
await destination.WriteAsync(contentLengthHeader, 0, contentLengthHeader.Length, cancellationToken);
await destination.WriteAsync(messageBytes, 0, messageBytes.Length, cancellationToken);
await destination.FlushAsync(cancellationToken);
}
}
}
private static async Task PumpTcpToWebSocketAsync(NetworkStream source, WebSocket destination, CancellationToken cancellationToken)
{
var readBuffer = new byte[_bufferSize];
var dapBuffer = new List<byte>();
while (!cancellationToken.IsCancellationRequested)
{
var bytesRead = await source.ReadAsync(readBuffer, 0, readBuffer.Length, cancellationToken);
if (bytesRead == 0)
{
break;
}
for (int i = 0; i < bytesRead; i++)
{
dapBuffer.Add(readBuffer[i]);
}
while (TryParseDapMessage(dapBuffer, out var messageBody))
{
await destination.SendAsync(
new ArraySegment<byte>(messageBody),
WebSocketMessageType.Text,
endOfMessage: true,
cancellationToken);
}
}
}
private static bool TryParseDapMessage(List<byte> buffer, out byte[] messageBody)
{
messageBody = null;
var headerEndMarker = new byte[] { (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' };
var headerEndIndex = FindSequence(buffer, headerEndMarker);
if (headerEndIndex == -1)
{
return false;
}
var headerBytes = buffer.GetRange(0, headerEndIndex).ToArray();
var headerText = Encoding.ASCII.GetString(headerBytes);
var contentLength = -1;
foreach (var line in headerText.Split(new[] { "\r\n" }, StringSplitOptions.RemoveEmptyEntries))
{
if (line.StartsWith("Content-Length:", StringComparison.OrdinalIgnoreCase))
{
var valueStart = line.IndexOf(':') + 1;
if (int.TryParse(line.Substring(valueStart).Trim(), out var parsedLength))
{
contentLength = parsedLength;
break;
}
}
}
if (contentLength < 0)
{
buffer.RemoveRange(0, headerEndIndex + 4);
return false;
}
var messageStart = headerEndIndex + 4;
var messageEnd = messageStart + contentLength;
if (buffer.Count < messageEnd)
{
return false;
}
messageBody = buffer.GetRange(messageStart, contentLength).ToArray();
buffer.RemoveRange(0, messageEnd);
return true;
}
private static int FindSequence(List<byte> buffer, byte[] sequence)
{
if (buffer.Count < sequence.Length)
{
return -1;
}
for (int i = 0; i <= buffer.Count - sequence.Length; i++)
{
var match = true;
for (int j = 0; j < sequence.Length; j++)
{
if (buffer[i + j] != sequence[j])
{
match = false;
break;
}
}
if (match)
{
return i;
}
}
return -1;
}
private static bool IsValidWebSocketRequest(string requestLine, IDictionary<string, string> headers)
{
if (string.IsNullOrWhiteSpace(requestLine))
{
return false;
}
var requestLineParts = requestLine.Split(' ');
if (requestLineParts.Length < 3 || !string.Equals(requestLineParts[0], "GET", StringComparison.OrdinalIgnoreCase))
{
return false;
}
return HeaderContainsToken(headers, "Connection", "Upgrade") &&
HeaderContainsToken(headers, "Upgrade", "websocket") &&
headers.ContainsKey("Sec-WebSocket-Key");
}
private static bool HeaderContainsToken(IDictionary<string, string> headers, string headerName, string expectedToken)
{
if (!headers.TryGetValue(headerName, out var headerValue) || string.IsNullOrWhiteSpace(headerValue))
{
return false;
}
return headerValue
.Split(',')
.Select(token => token.Trim())
.Any(token => string.Equals(token, expectedToken, StringComparison.OrdinalIgnoreCase));
}
private static string ComputeAcceptValue(string webSocketKey)
{
using (var sha1 = SHA1.Create())
{
var inputBytes = Encoding.ASCII.GetBytes($"{webSocketKey}{_webSocketAcceptMagic}");
var hashBytes = sha1.ComputeHash(inputBytes);
return Convert.ToBase64String(hashBytes);
}
}
private static async Task<string> ReadLineAsync(Stream stream, CancellationToken cancellationToken)
{
var lineBuilder = new StringBuilder();
var buffer = new byte[1];
var previousWasCarriageReturn = false;
while (true)
{
var bytesRead = await stream.ReadAsync(buffer, 0, 1, cancellationToken);
if (bytesRead == 0)
{
return lineBuilder.Length > 0 ? lineBuilder.ToString() : null;
}
var currentChar = (char)buffer[0];
if (currentChar == '\n' && previousWasCarriageReturn)
{
if (lineBuilder.Length > 0 && lineBuilder[lineBuilder.Length - 1] == '\r')
{
lineBuilder.Length--;
}
return lineBuilder.ToString();
}
previousWasCarriageReturn = currentChar == '\r';
lineBuilder.Append(currentChar);
if (lineBuilder.Length > _maxHeaderLineLength)
{
throw new InvalidDataException($"HTTP header line exceeds maximum length of {_maxHeaderLineLength}");
}
}
}
private static async Task<byte[]> ReadInitialBytesAsync(NetworkStream stream, CancellationToken cancellationToken)
{
var buffer = new byte[4];
var totalRead = 0;
while (totalRead < buffer.Length)
{
var bytesRead = await stream.ReadAsync(buffer, totalRead, buffer.Length - totalRead, cancellationToken);
if (bytesRead == 0)
{
break;
}
totalRead += bytesRead;
}
if (totalRead == 0)
{
return Array.Empty<byte>();
}
if (totalRead == buffer.Length)
{
return buffer;
}
var initialBytes = new byte[totalRead];
Array.Copy(buffer, initialBytes, totalRead);
return initialBytes;
}
internal static IncomingStreamPrefixKind ClassifyIncomingStreamPrefix(byte[] initialBytes)
{
if (LooksLikeHttpUpgrade(initialBytes))
{
return IncomingStreamPrefixKind.HttpWebSocketUpgrade;
}
if (LooksLikeHttp2Preface(initialBytes))
{
return IncomingStreamPrefixKind.Http2Preface;
}
if (LooksLikeTlsClientHello(initialBytes))
{
return IncomingStreamPrefixKind.TlsClientHello;
}
if (LooksLikeWebSocketFramePrefix(initialBytes, requireReservedBitsClear: false))
{
return HasReservedBitsSet(initialBytes[0])
? IncomingStreamPrefixKind.WebSocketReservedBits
: IncomingStreamPrefixKind.PreUpgradedWebSocket;
}
return IncomingStreamPrefixKind.Unknown;
}
internal static string DescribeInitialBytes(byte[] initialBytes)
{
if (initialBytes == null || initialBytes.Length == 0)
{
return "no bytes read";
}
var hex = BitConverter.ToString(initialBytes);
var ascii = new string(initialBytes.Select(value => value >= 32 && value <= 126 ? (char)value : '.').ToArray());
return $"hex={hex}, ascii=\"{ascii}\"";
}
private static bool LooksLikeHttpUpgrade(byte[] initialBytes)
{
if (initialBytes == null || initialBytes.Length < 4)
{
return false;
}
return initialBytes[0] == (byte)'G' &&
initialBytes[1] == (byte)'E' &&
initialBytes[2] == (byte)'T' &&
initialBytes[3] == (byte)' ';
}
private static bool LooksLikeHttp2Preface(byte[] initialBytes)
{
if (initialBytes == null || initialBytes.Length < 4)
{
return false;
}
return initialBytes[0] == (byte)'P' &&
initialBytes[1] == (byte)'R' &&
initialBytes[2] == (byte)'I' &&
initialBytes[3] == (byte)' ';
}
private static bool LooksLikeTlsClientHello(byte[] initialBytes)
{
if (initialBytes == null || initialBytes.Length < 3)
{
return false;
}
return initialBytes[0] == 0x16 &&
initialBytes[1] == 0x03 &&
initialBytes[2] >= 0x00 &&
initialBytes[2] <= 0x04;
}
private static bool LooksLikeWebSocketFramePrefix(byte[] initialBytes, bool requireReservedBitsClear)
{
if (initialBytes == null || initialBytes.Length < 2)
{
return false;
}
var firstByte = initialBytes[0];
var secondByte = initialBytes[1];
var opcode = firstByte & 0x0F;
var isMasked = (secondByte & 0x80) != 0;
if (!isMasked || !IsSupportedWebSocketOpcode(opcode))
{
return false;
}
return !requireReservedBitsClear || !HasReservedBitsSet(firstByte);
}
private static bool HasReservedBitsSet(byte firstByte)
{
return (firstByte & 0x70) != 0;
}
private static bool IsSupportedWebSocketOpcode(int opcode)
{
switch (opcode)
{
case 0x0:
case 0x1:
case 0x2:
case 0x8:
case 0x9:
case 0xA:
return true;
default:
return false;
}
}
private static async Task WriteHttpErrorAsync(
NetworkStream stream,
HttpStatusCode statusCode,
string message,
CancellationToken cancellationToken)
{
var bodyBytes = Encoding.UTF8.GetBytes(message);
var responseBytes = Encoding.ASCII.GetBytes(
$"HTTP/1.1 {(int)statusCode} {statusCode}\r\n" +
"Connection: close\r\n" +
"Content-Type: text/plain; charset=utf-8\r\n" +
$"Content-Length: {bodyBytes.Length}\r\n" +
"Sec-WebSocket-Version: 13\r\n" +
"\r\n");
await stream.WriteAsync(responseBytes, 0, responseBytes.Length, cancellationToken);
await stream.WriteAsync(bodyBytes, 0, bodyBytes.Length, cancellationToken);
await stream.FlushAsync(cancellationToken);
}
private static async Task CloseWebSocketAsync(WebSocket webSocket)
{
if (webSocket == null)
{
return;
}
if (webSocket.State != WebSocketState.Open &&
webSocket.State != WebSocketState.CloseReceived)
{
return;
}
try
{
using var cts = new CancellationTokenSource(_closeTimeout);
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cts.Token);
}
catch (OperationCanceledException)
{
// Graceful close timed out, abort the connection.
webSocket.Abort();
}
catch (WebSocketException)
{
// Peer already disconnected.
}
}
private sealed class ReplayableStream : Stream
{
private readonly Stream _innerStream;
private readonly byte[] _prefixBytes;
private int _prefixOffset;
public ReplayableStream(Stream innerStream, byte[] prefixBytes)
{
_innerStream = innerStream ?? throw new ArgumentNullException(nameof(innerStream));
_prefixBytes = prefixBytes ?? Array.Empty<byte>();
}
public override bool CanRead => _innerStream.CanRead;
public override bool CanSeek => false;
public override bool CanWrite => _innerStream.CanWrite;
public override long Length => throw new NotSupportedException();
public override long Position
{
get => throw new NotSupportedException();
set => throw new NotSupportedException();
}
public override void Flush() => _innerStream.Flush();
public override Task FlushAsync(CancellationToken cancellationToken) => _innerStream.FlushAsync(cancellationToken);
public override int Read(byte[] buffer, int offset, int count)
{
if (TryReadPrefix(buffer, offset, count, out var bytesRead))
{
return bytesRead;
}
return _innerStream.Read(buffer, offset, count);
}
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
if (TryReadPrefix(buffer, offset, count, out var bytesRead))
{
return bytesRead;
}
return await _innerStream.ReadAsync(buffer, offset, count, cancellationToken);
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
if (_prefixOffset < _prefixBytes.Length)
{
var bytesToCopy = Math.Min(buffer.Length, _prefixBytes.Length - _prefixOffset);
new ReadOnlySpan<byte>(_prefixBytes, _prefixOffset, bytesToCopy).CopyTo(buffer.Span);
_prefixOffset += bytesToCopy;
return bytesToCopy;
}
return await _innerStream.ReadAsync(buffer, cancellationToken);
}
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
public override void Write(byte[] buffer, int offset, int count) => _innerStream.Write(buffer, offset, count);
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
_innerStream.WriteAsync(buffer, offset, count, cancellationToken);
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) =>
_innerStream.WriteAsync(buffer, cancellationToken);
private bool TryReadPrefix(byte[] buffer, int offset, int count, out int bytesRead)
{
if (_prefixOffset >= _prefixBytes.Length)
{
bytesRead = 0;
return false;
}
bytesRead = Math.Min(count, _prefixBytes.Length - _prefixOffset);
Array.Copy(_prefixBytes, _prefixOffset, buffer, offset, bytesRead);
_prefixOffset += bytesRead;
return true;
}
}
}
}

View File

@@ -1,7 +1,8 @@
using System;
using System;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Net.WebSockets;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
@@ -20,12 +21,13 @@ namespace GitHub.Runner.Common.Tests.Worker
private const string TunnelConnectTimeoutVariable = "ACTIONS_RUNNER_DAP_TUNNEL_CONNECT_TIMEOUT_SECONDS";
private DapDebugger _debugger;
private TestHostContext CreateTestContext([CallerMemberName] string testName = "")
private TestHostContext CreateTestContext(bool enableWebSocketBridge = false, [CallerMemberName] string testName = "")
{
var hc = new TestHostContext(this, testName);
_debugger = new DapDebugger();
_debugger.Initialize(hc);
_debugger.SkipTunnelRelay = true;
_debugger.SkipWebSocketBridge = !enableWebSocketBridge;
return hc;
}
@@ -71,6 +73,13 @@ namespace GitHub.Runner.Common.Tests.Worker
return client;
}
private static async Task<ClientWebSocket> ConnectWebSocketClientAsync(int port)
{
var client = new ClientWebSocket();
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
return client;
}
private static async Task SendRequestAsync(NetworkStream stream, Request request)
{
var json = JsonConvert.SerializeObject(request);
@@ -83,6 +92,14 @@ namespace GitHub.Runner.Common.Tests.Worker
await stream.FlushAsync();
}
private static async Task SendRequestAsync(WebSocket client, Request request)
{
var json = JsonConvert.SerializeObject(request);
var body = Encoding.UTF8.GetBytes(json);
await client.SendAsync(new ArraySegment<byte>(body), WebSocketMessageType.Text, endOfMessage: true, CancellationToken.None);
}
/// <summary>
/// Reads a single DAP-framed message from a stream with a timeout.
/// Parses the Content-Length header, reads exactly that many bytes,
@@ -141,6 +158,52 @@ namespace GitHub.Runner.Common.Tests.Worker
return Encoding.UTF8.GetString(body);
}
private static async Task<string> ReadWebSocketDataUntilAsync(WebSocket client, TimeSpan timeout, params string[] expectedFragments)
{
using var cts = new CancellationTokenSource(timeout);
var buffer = new byte[4096];
var allMessages = new StringBuilder();
while (true)
{
using var messageStream = new MemoryStream();
WebSocketReceiveResult result;
do
{
result = await client.ReceiveAsync(new ArraySegment<byte>(buffer), cts.Token);
if (result.MessageType == WebSocketMessageType.Close)
{
throw new EndOfStreamException("WebSocket closed before expected DAP messages were received.");
}
if (result.Count > 0)
{
messageStream.Write(buffer, 0, result.Count);
}
}
while (!result.EndOfMessage);
var messageText = Encoding.UTF8.GetString(messageStream.ToArray());
allMessages.Append(messageText);
var text = allMessages.ToString();
var containsAllFragments = true;
foreach (var fragment in expectedFragments)
{
if (!text.Contains(fragment, StringComparison.Ordinal))
{
containsAllFragments = false;
break;
}
}
if (containsAllFragments)
{
return text;
}
}
}
private static Mock<IExecutionContext> CreateJobContextWithTunnel(CancellationToken cancellationToken, ushort port, string jobName = null)
{
var tunnel = new GitHub.DistributedTask.Pipelines.DebuggerTunnelInfo
@@ -208,6 +271,82 @@ namespace GitHub.Runner.Common.Tests.Worker
}
}
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Worker")]
public async Task StartAsyncWithWebSocketBridgeAcceptsInitializeOverWebSocket()
{
using (CreateTestContext(enableWebSocketBridge: true))
{
var port = GetFreePort();
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
var jobContext = CreateJobContextWithTunnel(cts.Token, port);
await _debugger.StartAsync(jobContext.Object);
Assert.NotEqual(0, _debugger.InternalDapPort);
Assert.NotEqual(port, _debugger.InternalDapPort);
using var client = await ConnectWebSocketClientAsync(port);
await SendRequestAsync(client, new Request
{
Seq = 1,
Type = "request",
Command = "initialize"
});
var response = await ReadWebSocketDataUntilAsync(
client,
TimeSpan.FromSeconds(5),
"\"type\":\"response\"",
"\"command\":\"initialize\"",
"\"event\":\"initialized\"");
Assert.Contains("\"success\":true", response);
await _debugger.StopAsync();
}
}
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Worker")]
public async Task StartAsyncWithWebSocketBridgeAcceptsPreUpgradedWebSocketStream()
{
using (CreateTestContext(enableWebSocketBridge: true))
{
var port = GetFreePort();
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
var jobContext = CreateJobContextWithTunnel(cts.Token, port);
await _debugger.StartAsync(jobContext.Object);
Assert.NotEqual(0, _debugger.InternalDapPort);
Assert.NotEqual(port, _debugger.InternalDapPort);
using var tcpClient = await ConnectClientAsync(port);
using var webSocket = WebSocket.CreateFromStream(
tcpClient.GetStream(),
isServer: false,
subProtocol: null,
keepAliveInterval: TimeSpan.FromSeconds(30));
await SendRequestAsync(webSocket, new Request
{
Seq = 1,
Type = "request",
Command = "initialize"
});
var response = await ReadWebSocketDataUntilAsync(
webSocket,
TimeSpan.FromSeconds(5),
"\"type\":\"response\"",
"\"command\":\"initialize\"",
"\"event\":\"initialized\"");
Assert.Contains("\"success\":true", response);
await _debugger.StopAsync();
}
}
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Worker")]

View File

@@ -0,0 +1,245 @@
using System;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Net.WebSockets;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using GitHub.Runner.Common;
using GitHub.Runner.Worker.Dap;
using Xunit;
namespace GitHub.Runner.Common.Tests.Worker
{
public sealed class WebSocketDapBridgeL0
{
private TestHostContext CreateTestContext([CallerMemberName] string testName = "")
{
return new TestHostContext(this, testName);
}
private static ushort GetFreePort()
{
using var listener = new TcpListener(IPAddress.Loopback, 0);
listener.Start();
return (ushort)((IPEndPoint)listener.LocalEndpoint).Port;
}
private static async Task<byte[]> ReadWebSocketMessageAsync(ClientWebSocket client, TimeSpan timeout)
{
using var cts = new CancellationTokenSource(timeout);
using var buffer = new MemoryStream();
var receiveBuffer = new byte[1024];
while (true)
{
var result = await client.ReceiveAsync(new ArraySegment<byte>(receiveBuffer), cts.Token);
if (result.MessageType == WebSocketMessageType.Close)
{
throw new EndOfStreamException("WebSocket closed unexpectedly.");
}
if (result.Count > 0)
{
buffer.Write(receiveBuffer, 0, result.Count);
}
if (result.EndOfMessage)
{
return buffer.ToArray();
}
}
}
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Worker")]
public async Task BridgeForwardsWebSocketFramesToTcpAndBack()
{
using var hc = CreateTestContext();
using var targetListener = new TcpListener(IPAddress.Loopback, 0);
targetListener.Start();
var targetPort = ((IPEndPoint)targetListener.LocalEndpoint).Port;
var bridgePort = GetFreePort();
await using var bridge = new WebSocketDapBridge(hc.GetTrace("DapWebSocketBridge"), bridgePort, targetPort);
bridge.Start();
var echoTask = Task.Run(async () =>
{
using var targetClient = await targetListener.AcceptTcpClientAsync();
using var stream = targetClient.GetStream();
var headerBuilder = new StringBuilder();
var buffer = new byte[1];
var contentLength = -1;
while (true)
{
var bytesRead = await stream.ReadAsync(buffer, 0, 1);
if (bytesRead == 0) break;
headerBuilder.Append((char)buffer[0]);
var headers = headerBuilder.ToString();
if (headers.EndsWith("\r\n\r\n", StringComparison.Ordinal))
{
foreach (var line in headers.Split(new[] { "\r\n" }, StringSplitOptions.RemoveEmptyEntries))
{
if (line.StartsWith("Content-Length: ", StringComparison.OrdinalIgnoreCase))
{
contentLength = int.Parse(line.Substring("Content-Length: ".Length).Trim());
}
}
break;
}
}
var body = new byte[contentLength];
var totalRead = 0;
while (totalRead < contentLength)
{
var bytesRead = await stream.ReadAsync(body, totalRead, contentLength - totalRead);
if (bytesRead == 0) break;
totalRead += bytesRead;
}
var header = $"Content-Length: {body.Length}\r\n\r\n";
var headerBytes = Encoding.ASCII.GetBytes(header);
await stream.WriteAsync(headerBytes, 0, headerBytes.Length);
await stream.WriteAsync(body, 0, body.Length);
await stream.FlushAsync();
});
using var client = new ClientWebSocket();
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{bridgePort}/"), CancellationToken.None);
var dapMessage = "{\"type\":\"request\",\"seq\":1,\"command\":\"initialize\"}";
var payload = Encoding.UTF8.GetBytes(dapMessage);
await client.SendAsync(new ArraySegment<byte>(payload), WebSocketMessageType.Text, endOfMessage: true, CancellationToken.None);
var echoed = await ReadWebSocketMessageAsync(client, TimeSpan.FromSeconds(5));
Assert.Equal(payload, echoed);
await echoTask;
}
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Worker")]
public async Task BridgeRejectsNonWebSocketRequests()
{
using var hc = CreateTestContext();
var bridgePort = GetFreePort();
await using var bridge = new WebSocketDapBridge(hc.GetTrace("DapWebSocketBridge"), bridgePort, GetFreePort());
bridge.Start();
using var client = new TcpClient();
await client.ConnectAsync(IPAddress.Loopback, bridgePort);
using var stream = client.GetStream();
var request = Encoding.ASCII.GetBytes(
"GET / HTTP/1.1\r\n" +
"Host: localhost\r\n" +
"\r\n");
await stream.WriteAsync(request, 0, request.Length);
await stream.FlushAsync();
// Read until the server closes the connection (Connection: close).
// A single ReadAsync may return a partial response on some platforms.
using var ms = new MemoryStream();
var responseBuffer = new byte[1024];
int bytesRead;
while ((bytesRead = await stream.ReadAsync(responseBuffer, 0, responseBuffer.Length)) > 0)
{
ms.Write(responseBuffer, 0, bytesRead);
}
var response = Encoding.ASCII.GetString(ms.ToArray());
Assert.Contains("400 BadRequest", response);
Assert.Contains("Expected a websocket upgrade request.", response);
}
[Theory]
[Trait("Level", "L0")]
[Trait("Category", "Worker")]
[InlineData(new byte[] { (byte)'G', (byte)'E', (byte)'T', (byte)' ' }, 1)]
[InlineData(new byte[] { 0x81, 0x85, 0x00, 0x00 }, 2)]
[InlineData(new byte[] { 0xC1, 0x85, 0x00, 0x00 }, 3)]
[InlineData(new byte[] { (byte)'P', (byte)'R', (byte)'I', (byte)' ' }, 4)]
[InlineData(new byte[] { 0x16, 0x03, 0x03, 0x01 }, 5)]
[InlineData(new byte[] { (byte)'B', (byte)'A', (byte)'D', (byte)'!' }, 0)]
public void ClassifyIncomingStreamPrefixDetectsExpectedProtocols(byte[] initialBytes, int expectedKind)
{
var actualKind = WebSocketDapBridge.ClassifyIncomingStreamPrefix(initialBytes);
Assert.Equal((WebSocketDapBridge.IncomingStreamPrefixKind)expectedKind, actualKind);
}
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Worker")]
public async Task BridgeRejectsOversizedWebSocketMessage()
{
using var hc = CreateTestContext();
using var targetListener = new TcpListener(IPAddress.Loopback, 0);
targetListener.Start();
var targetPort = ((IPEndPoint)targetListener.LocalEndpoint).Port;
var bridgePort = GetFreePort();
await using var bridge = new WebSocketDapBridge(hc.GetTrace("DapWebSocketBridge"), bridgePort, targetPort);
bridge.MaxInboundMessageSize = 64; // artificially small limit for testing
bridge.Start();
using var client = new ClientWebSocket();
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{bridgePort}/"), CancellationToken.None);
// Send a message that exceeds the 64-byte limit
var oversizedPayload = new byte[128];
Array.Fill(oversizedPayload, (byte)'X');
await client.SendAsync(
new ArraySegment<byte>(oversizedPayload),
WebSocketMessageType.Text,
endOfMessage: true,
CancellationToken.None);
// The bridge should close the connection with MessageTooBig
var receiveBuffer = new byte[256];
var result = await client.ReceiveAsync(
new ArraySegment<byte>(receiveBuffer),
new CancellationTokenSource(TimeSpan.FromSeconds(5)).Token);
Assert.Equal(WebSocketMessageType.Close, result.MessageType);
Assert.Equal(WebSocketCloseStatus.MessageTooBig, client.CloseStatus);
}
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Worker")]
public async Task BridgeDisposeCompletesWhenPeerDoesNotCloseGracefully()
{
using var hc = CreateTestContext();
using var targetListener = new TcpListener(IPAddress.Loopback, 0);
targetListener.Start();
var targetPort = ((IPEndPoint)targetListener.LocalEndpoint).Port;
var bridgePort = GetFreePort();
var bridge = new WebSocketDapBridge(hc.GetTrace("DapWebSocketBridge"), bridgePort, targetPort);
bridge.Start();
// Connect a raw TCP client but never perform WebSocket close handshake
using var rawClient = new TcpClient();
await rawClient.ConnectAsync(IPAddress.Loopback, bridgePort);
// Dispose should complete within a bounded time, not hang
var disposeTask = bridge.DisposeAsync().AsTask();
var completed = await Task.WhenAny(disposeTask, Task.Delay(TimeSpan.FromSeconds(15)));
Assert.True(completed == disposeTask, "Bridge dispose should complete within the timeout, not hang on a non-cooperative peer");
}
}
}