Enable auth migration based on config refresh. (#3786)

This commit is contained in:
Tingluo Huang
2025-04-02 23:24:57 -04:00
committed by GitHub
parent a0a0a76378
commit c1095ae2d1
9 changed files with 94 additions and 32 deletions

View File

@@ -37,6 +37,7 @@ namespace GitHub.Runner.Common
public async Task ConnectAsync(Uri serverUri, VssCredentials credentials)
{
Trace.Entering();
_brokerUri = serverUri;
_connection = VssUtil.CreateRawConnection(serverUri, credentials);

View File

@@ -245,7 +245,6 @@ namespace GitHub.Runner.Listener
// Decrypt the message body if the session is using encryption
message = DecryptMessage(message);
if (message != null && message.MessageType == BrokerMigrationMessage.MessageType)
{
var migrationMessage = JsonUtility.FromString<BrokerMigrationMessage>(message.Body);
@@ -306,6 +305,10 @@ namespace GitHub.Runner.Listener
Trace.Error("Catch exception during get next message.");
Trace.Error(ex);
// clear out potential message for broker migration,
// in case the exception is thrown from get message from broker-listener.
message = null;
// don't retry if SkipSessionRecover = true, DT service will delete agent session to stop agent from taking more jobs.
if (ex is TaskAgentSessionExpiredException && !_settings.SkipSessionRecover && (await CreateSessionAsync(token) == CreateSessionResult.Success))
{

View File

@@ -31,6 +31,7 @@ namespace GitHub.Runner.Listener
private ITerminal _term;
private bool _inConfigStage;
private ManualResetEvent _completedCommand = new(false);
private IRunnerServer _runnerServer;
// <summary>
// Helps avoid excessive calls to Run Service when encountering non-retriable errors from /acquirejob.
@@ -51,6 +52,7 @@ namespace GitHub.Runner.Listener
base.Initialize(hostContext);
_term = HostContext.GetService<ITerminal>();
_acquireJobThrottler = HostContext.CreateService<IErrorThrottler>();
_runnerServer = HostContext.GetService<IRunnerServer>();
}
public async Task<int> ExecuteCommand(CommandSettings command)

View File

@@ -211,7 +211,17 @@ namespace GitHub.Runner.Listener
// save the refreshed runner credentials as a separate file
_store.SaveMigratedCredential(refreshedCredConfig);
await ReportTelemetryAsync("Runner credentials updated successfully.");
if (refreshedCredConfig.Data.ContainsKey("authorizationUrlV2"))
{
HostContext.EnableAuthMigration("Credential file updated");
await ReportTelemetryAsync("Runner credentials updated successfully. Auth migration is enabled.");
}
else
{
HostContext.DeferAuthMigration(TimeSpan.FromDays(365), "Credential file does not contain authorizationUrlV2");
await ReportTelemetryAsync("Runner credentials updated successfully. Auth migration is disabled.");
}
}
private async Task<bool> VerifyRunnerQualifiedId(string runnerQualifiedId)

View File

@@ -18,8 +18,6 @@ namespace GitHub.Runner.Common.Tests.Listener
private readonly Mock<IBrokerServer> _brokerServer;
private readonly Mock<IRunnerServer> _runnerServer;
private readonly Mock<ICredentialManager> _credMgr;
private Mock<IConfigurationStore> _store;
public BrokerMessageListenerL0()
{
@@ -27,7 +25,6 @@ namespace GitHub.Runner.Common.Tests.Listener
_config = new Mock<IConfigurationManager>();
_config.Setup(x => x.LoadSettings()).Returns(_settings);
_credMgr = new Mock<ICredentialManager>();
_store = new Mock<IConfigurationStore>();
_brokerServer = new Mock<IBrokerServer>();
_runnerServer = new Mock<IRunnerServer>();
}
@@ -35,7 +32,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void CreatesSession()
public async Task CreatesSession()
{
using (TestHostContext tc = CreateTestContext())
using (var tokenSource = new CancellationTokenSource())
@@ -51,8 +48,6 @@ namespace GitHub.Runner.Common.Tests.Listener
.Returns(Task.FromResult(expectedSession));
_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));
// Act.
BrokerMessageListener listener = new();
@@ -75,7 +70,6 @@ namespace GitHub.Runner.Common.Tests.Listener
TestHostContext tc = new(this, testName);
tc.SetSingleton<IConfigurationManager>(_config.Object);
tc.SetSingleton<ICredentialManager>(_credMgr.Object);
tc.SetSingleton<IConfigurationStore>(_store.Object);
tc.SetSingleton<IBrokerServer>(_brokerServer.Object);
tc.SetSingleton<IRunnerServer>(_runnerServer.Object);
return tc;

View File

@@ -51,7 +51,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void CreatesSession()
public async Task CreatesSession()
{
using (TestHostContext tc = CreateTestContext())
using (var tokenSource = new CancellationTokenSource())
@@ -95,7 +95,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void DeleteSession()
public async Task DeleteSession()
{
using (TestHostContext tc = CreateTestContext())
using (var tokenSource = new CancellationTokenSource())
@@ -142,7 +142,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void GetNextMessage()
public async Task GetNextMessage()
{
using (TestHostContext tc = CreateTestContext())
using (var tokenSource = new CancellationTokenSource())
@@ -223,7 +223,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void GetNextMessageWithBrokerMigration()
public async Task GetNextMessageWithBrokerMigration()
{
using (TestHostContext tc = CreateTestContext())
using (var tokenSource = new CancellationTokenSource())
@@ -329,7 +329,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void CreateSessionWithOriginalCredential()
public async Task CreateSessionWithOriginalCredential()
{
using (TestHostContext tc = CreateTestContext())
using (var tokenSource = new CancellationTokenSource())
@@ -374,7 +374,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void SkipDeleteSession_WhenGetNextMessageGetTaskAgentAccessTokenExpiredException()
public async Task SkipDeleteSession_WhenGetNextMessageGetTaskAgentAccessTokenExpiredException()
{
using (TestHostContext tc = CreateTestContext())
using (var tokenSource = new CancellationTokenSource())

View File

@@ -210,9 +210,9 @@ namespace GitHub.Runner.Tests.Listener
var encodedConfig = Convert.ToBase64String(Encoding.UTF8.GetBytes(StringUtil.ConvertToJson(credData)));
_runnerServer.Setup(x => x.RefreshRunnerConfigAsync(It.IsAny<int>(), It.Is<string>(s => s == "credentials"), It.IsAny<string>(), It.IsAny<CancellationToken>())).ReturnsAsync(encodedConfig);
var _runnerConfigUpdater = new RunnerConfigUpdater();
_runnerConfigUpdater.Initialize(hc);
hc.EnableAuthMigration("L0Test");
var validRunnerQualifiedId = "valid/runner/qualifiedid/1";
var configType = "credentials";
@@ -226,6 +226,7 @@ namespace GitHub.Runner.Tests.Listener
_runnerServer.Verify(x => x.RefreshRunnerConfigAsync(1, "credentials", It.IsAny<string>(), It.IsAny<CancellationToken>()), Times.Once);
_runnerServer.Verify(x => x.UpdateAgentUpdateStateAsync(It.IsAny<int>(), It.IsAny<ulong>(), It.IsAny<string>(), It.Is<string>(s => s.Contains("Runner credentials updated successfully")), It.IsAny<CancellationToken>()), Times.Once);
_configurationStore.Verify(x => x.SaveMigratedCredential(It.IsAny<CredentialData>()), Times.Once);
Assert.False(hc.AllowAuthMigration);
}
}
@@ -306,7 +307,7 @@ namespace GitHub.Runner.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async Task UpdateRunnerConfigAsync_RefreshRunnerCredetialsFailure_ShouldReportTelemetry()
public async Task UpdateRunnerConfigAsync_RefreshRunnerCredentialsFailure_ShouldReportTelemetry()
{
using (var hc = new TestHostContext(this))
{
@@ -625,5 +626,53 @@ namespace GitHub.Runner.Tests.Listener
_configurationStore.Verify(x => x.SaveMigratedSettings(It.IsAny<RunnerSettings>()), Times.Never);
}
}
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async Task UpdateRunnerConfigAsync_UpdateRunnerCredentials_EnableDisableAuthMigration()
{
using (var hc = new TestHostContext(this))
{
hc.SetSingleton<IConfigurationStore>(_configurationStore.Object);
hc.SetSingleton<IRunnerServer>(_runnerServer.Object);
// Arrange
var setting = new RunnerSettings { AgentId = 1, AgentName = "agent1" };
_configurationStore.Setup(x => x.GetSettings()).Returns(setting);
var credData = new CredentialData
{
Scheme = "OAuth"
};
credData.Data.Add("ClientId", "12345");
credData.Data.Add("AuthorizationUrl", "https://example.com");
credData.Data.Add("AuthorizationUrlV2", "https://example2.com");
_configurationStore.Setup(x => x.GetCredentials()).Returns(credData);
IOUtil.SaveObject(setting, hc.GetConfigFile(WellKnownConfigFile.Runner));
IOUtil.SaveObject(credData, hc.GetConfigFile(WellKnownConfigFile.Credentials));
var encodedConfig = Convert.ToBase64String(Encoding.UTF8.GetBytes(StringUtil.ConvertToJson(credData)));
_runnerServer.Setup(x => x.RefreshRunnerConfigAsync(It.IsAny<int>(), It.Is<string>(s => s == "credentials"), It.IsAny<string>(), It.IsAny<CancellationToken>())).ReturnsAsync(encodedConfig);
var _runnerConfigUpdater = new RunnerConfigUpdater();
_runnerConfigUpdater.Initialize(hc);
Assert.False(hc.AllowAuthMigration);
var validRunnerQualifiedId = "valid/runner/qualifiedid/1";
var configType = "credentials";
var serviceType = "pipelines";
var configRefreshUrl = "http://example.com";
// Act
await _runnerConfigUpdater.UpdateRunnerConfigAsync(validRunnerQualifiedId, configType, serviceType, configRefreshUrl);
// Assert
_runnerServer.Verify(x => x.RefreshRunnerConfigAsync(1, "credentials", It.IsAny<string>(), It.IsAny<CancellationToken>()), Times.Once);
_runnerServer.Verify(x => x.UpdateAgentUpdateStateAsync(It.IsAny<int>(), It.IsAny<ulong>(), It.IsAny<string>(), It.Is<string>(s => s.Contains("Runner credentials updated successfully")), It.IsAny<CancellationToken>()), Times.Once);
_configurationStore.Verify(x => x.SaveMigratedCredential(It.IsAny<CredentialData>()), Times.Once);
Assert.True(hc.AllowAuthMigration);
}
}
}
}

View File

@@ -1,13 +1,13 @@
using GitHub.DistributedTask.WebApi;
using GitHub.Runner.Listener;
using GitHub.Runner.Listener.Configuration;
using Moq;
using System;
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using GitHub.DistributedTask.WebApi;
using GitHub.Runner.Listener;
using GitHub.Runner.Listener.Configuration;
using GitHub.Services.WebApi;
using Moq;
using Xunit;
using Pipelines = GitHub.DistributedTask.Pipelines;
namespace GitHub.Runner.Common.Tests.Listener
@@ -57,7 +57,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
//process 2 new job messages, and one cancel message
public async void TestRunAsync()
public async Task TestRunAsync()
{
using (var hc = new TestHostContext(this))
{
@@ -169,7 +169,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[MemberData(nameof(RunAsServiceTestData))]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void TestExecuteCommandForRunAsService(string[] args, bool configureAsService, Times expectedTimes)
public async Task TestExecuteCommandForRunAsService(string[] args, bool configureAsService, Times expectedTimes)
{
using (var hc = new TestHostContext(this))
{
@@ -177,6 +177,7 @@ namespace GitHub.Runner.Common.Tests.Listener
hc.SetSingleton<IPromptManager>(_promptManager.Object);
hc.SetSingleton<IMessageListener>(_messageListener.Object);
hc.SetSingleton<IConfigurationStore>(_configStore.Object);
hc.SetSingleton<IRunnerServer>(_runnerServer.Object);
hc.EnqueueInstance<IErrorThrottler>(_acquireJobThrottler.Object);
var command = new CommandSettings(hc, args);
@@ -201,7 +202,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void TestMachineProvisionerCLI()
public async Task TestMachineProvisionerCLI()
{
using (var hc = new TestHostContext(this))
{
@@ -209,6 +210,7 @@ namespace GitHub.Runner.Common.Tests.Listener
hc.SetSingleton<IPromptManager>(_promptManager.Object);
hc.SetSingleton<IMessageListener>(_messageListener.Object);
hc.SetSingleton<IConfigurationStore>(_configStore.Object);
hc.SetSingleton<IRunnerServer>(_runnerServer.Object);
hc.EnqueueInstance<IErrorThrottler>(_acquireJobThrottler.Object);
var command = new CommandSettings(hc, new[] { "run" });
@@ -235,7 +237,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void TestRunOnce()
public async Task TestRunOnce()
{
using (var hc = new TestHostContext(this))
{
@@ -332,7 +334,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void TestRunOnceOnlyTakeOneJobMessage()
public async Task TestRunOnceOnlyTakeOneJobMessage()
{
using (var hc = new TestHostContext(this))
{
@@ -433,7 +435,7 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void TestRunOnceHandleUpdateMessage()
public async Task TestRunOnceHandleUpdateMessage()
{
using (var hc = new TestHostContext(this))
{
@@ -523,13 +525,14 @@ namespace GitHub.Runner.Common.Tests.Listener
[Fact]
[Trait("Level", "L0")]
[Trait("Category", "Runner")]
public async void TestRemoveLocalRunnerConfig()
public async Task TestRemoveLocalRunnerConfig()
{
using (var hc = new TestHostContext(this))
{
hc.SetSingleton<IConfigurationManager>(_configurationManager.Object);
hc.SetSingleton<IConfigurationStore>(_configStore.Object);
hc.SetSingleton<IPromptManager>(_promptManager.Object);
hc.SetSingleton<IRunnerServer>(_runnerServer.Object);
hc.EnqueueInstance<IErrorThrottler>(_acquireJobThrottler.Object);
var command = new CommandSettings(hc, new[] { "remove", "--local" });

View File

@@ -103,8 +103,8 @@ namespace GitHub.Runner.Common.Tests
handler(this, new DelayEventArgs(delay, token));
}
// Delay zero
await Task.Delay(TimeSpan.Zero);
// Delay 10ms
await Task.Delay(TimeSpan.FromMilliseconds(10));
}
public T CreateService<T>() where T : class, IRunnerService