diff --git a/cmd/ghalistener/listener/listener.go b/cmd/ghalistener/listener/listener.go index 56da0a8f..f433cbb9 100644 --- a/cmd/ghalistener/listener/listener.go +++ b/cmd/ghalistener/listener/listener.go @@ -31,7 +31,7 @@ const ( type Client interface { GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*actions.AcquirableJobList, error) CreateMessageSession(ctx context.Context, runnerScaleSetId int, owner string) (*actions.RunnerScaleSetSession, error) - GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*actions.RunnerScaleSetMessage, error) + GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*actions.RunnerScaleSetMessage, error) DeleteMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, messageId int64) error AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) RefreshMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) (*actions.RunnerScaleSetSession, error) @@ -80,6 +80,7 @@ type Listener struct { // updated fields lastMessageID int64 // The ID of the last processed message. + maxCapacity int // The maximum number of runners that can be created. session *actions.RunnerScaleSetSession // The session for managing the runner scale set. } @@ -89,10 +90,11 @@ func New(config Config) (*Listener, error) { } listener := &Listener{ - scaleSetID: config.ScaleSetID, - client: config.Client, - logger: config.Logger, - metrics: metrics.Discard, + scaleSetID: config.ScaleSetID, + client: config.Client, + logger: config.Logger, + metrics: metrics.Discard, + maxCapacity: config.MaxRunners, } if config.Metrics != nil { @@ -267,7 +269,7 @@ func (l *Listener) createSession(ctx context.Context) error { func (l *Listener) getMessage(ctx context.Context) (*actions.RunnerScaleSetMessage, error) { l.logger.Info("Getting next message", "lastMessageID", l.lastMessageID) - msg, err := l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID) + msg, err := l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID, l.maxCapacity) if err == nil { // if NO error return msg, nil } @@ -283,7 +285,7 @@ func (l *Listener) getMessage(ctx context.Context) (*actions.RunnerScaleSetMessa l.logger.Info("Getting next message", "lastMessageID", l.lastMessageID) - msg, err = l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID) + msg, err = l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID, l.maxCapacity) if err != nil { // if NO error return nil, fmt.Errorf("failed to get next message after message session refresh: %w", err) } diff --git a/cmd/ghalistener/listener/listener_test.go b/cmd/ghalistener/listener/listener_test.go index 610abc40..c39a0773 100644 --- a/cmd/ghalistener/listener/listener_test.go +++ b/cmd/ghalistener/listener/listener_test.go @@ -123,13 +123,14 @@ func TestListener_getMessage(t *testing.T) { config := Config{ ScaleSetID: 1, Metrics: metrics.Discard, + MaxRunners: 10, } client := listenermocks.NewClient(t) want := &actions.RunnerScaleSetMessage{ MessageId: 1, } - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(want, nil).Once() + client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(want, nil).Once() config.Client = client l, err := New(config) @@ -148,10 +149,11 @@ func TestListener_getMessage(t *testing.T) { config := Config{ ScaleSetID: 1, Metrics: metrics.Discard, + MaxRunners: 10, } client := listenermocks.NewClient(t) - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.HttpClientSideError{Code: http.StatusNotFound}).Once() + client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(nil, &actions.HttpClientSideError{Code: http.StatusNotFound}).Once() config.Client = client l, err := New(config) @@ -170,6 +172,7 @@ func TestListener_getMessage(t *testing.T) { config := Config{ ScaleSetID: 1, Metrics: metrics.Discard, + MaxRunners: 10, } client := listenermocks.NewClient(t) @@ -185,12 +188,12 @@ func TestListener_getMessage(t *testing.T) { } client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() + client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() want := &actions.RunnerScaleSetMessage{ MessageId: 1, } - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(want, nil).Once() + client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(want, nil).Once() config.Client = client @@ -214,6 +217,7 @@ func TestListener_getMessage(t *testing.T) { config := Config{ ScaleSetID: 1, Metrics: metrics.Discard, + MaxRunners: 10, } client := listenermocks.NewClient(t) @@ -229,7 +233,7 @@ func TestListener_getMessage(t *testing.T) { } client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.MessageQueueTokenExpiredError{}).Twice() + client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(nil, &actions.MessageQueueTokenExpiredError{}).Twice() config.Client = client @@ -450,6 +454,7 @@ func TestListener_Listen(t *testing.T) { config := Config{ ScaleSetID: 1, Metrics: metrics.Discard, + MaxRunners: 10, } client := listenermocks.NewClient(t) @@ -470,7 +475,7 @@ func TestListener_Listen(t *testing.T) { MessageType: "RunnerScaleSetJobMessages", Statistics: &actions.RunnerScaleSetStatistic{}, } - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything). + client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10). Return(msg, nil). Run( func(mock.Arguments) { diff --git a/cmd/ghalistener/listener/mocks/client.go b/cmd/ghalistener/listener/mocks/client.go index 9c3d38fd..a36c9344 100644 --- a/cmd/ghalistener/listener/mocks/client.go +++ b/cmd/ghalistener/listener/mocks/client.go @@ -123,25 +123,25 @@ func (_m *Client) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) ( return r0, r1 } -// GetMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId -func (_m *Client) GetMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, lastMessageId int64) (*actions.RunnerScaleSetMessage, error) { - ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) +// GetMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity +func (_m *Client) GetMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*actions.RunnerScaleSetMessage, error) { + ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity) var r0 *actions.RunnerScaleSetMessage var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) (*actions.RunnerScaleSetMessage, error)); ok { - return rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) + if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, int) (*actions.RunnerScaleSetMessage, error)); ok { + return rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *actions.RunnerScaleSetMessage); ok { - r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) + if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, int) *actions.RunnerScaleSetMessage); ok { + r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*actions.RunnerScaleSetMessage) } } - if rf, ok := ret.Get(1).(func(context.Context, string, string, int64) error); ok { - r1 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) + if rf, ok := ret.Get(1).(func(context.Context, string, string, int64, int) error); ok { + r1 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity) } else { r1 = ret.Error(1) } diff --git a/cmd/githubrunnerscalesetlistener/autoScalerMessageListener.go b/cmd/githubrunnerscalesetlistener/autoScalerMessageListener.go index 0d7f5a2b..26c5072d 100644 --- a/cmd/githubrunnerscalesetlistener/autoScalerMessageListener.go +++ b/cmd/githubrunnerscalesetlistener/autoScalerMessageListener.go @@ -129,7 +129,7 @@ func (m *AutoScalerClient) Close() error { return m.client.Close() } -func (m *AutoScalerClient) GetRunnerScaleSetMessage(ctx context.Context, handler func(msg *actions.RunnerScaleSetMessage) error) error { +func (m *AutoScalerClient) GetRunnerScaleSetMessage(ctx context.Context, handler func(msg *actions.RunnerScaleSetMessage) error, maxCapacity int) error { if m.initialMessage != nil { err := handler(m.initialMessage) if err != nil { @@ -141,7 +141,7 @@ func (m *AutoScalerClient) GetRunnerScaleSetMessage(ctx context.Context, handler } for { - message, err := m.client.GetMessage(ctx, m.lastMessageId) + message, err := m.client.GetMessage(ctx, m.lastMessageId, maxCapacity) if err != nil { return fmt.Errorf("get message failed from refreshing client. %w", err) } diff --git a/cmd/githubrunnerscalesetlistener/autoScalerMessageListener_test.go b/cmd/githubrunnerscalesetlistener/autoScalerMessageListener_test.go index 2d6ef711..c48a9a54 100644 --- a/cmd/githubrunnerscalesetlistener/autoScalerMessageListener_test.go +++ b/cmd/githubrunnerscalesetlistener/autoScalerMessageListener_test.go @@ -317,7 +317,7 @@ func TestGetRunnerScaleSetMessage(t *testing.T) { Statistics: &actions.RunnerScaleSetStatistic{}, } mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) - mockSessionClient.On("GetMessage", ctx, int64(0)).Return(&actions.RunnerScaleSetMessage{ + mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(&actions.RunnerScaleSetMessage{ MessageId: 1, MessageType: "test", Body: "test", @@ -332,7 +332,7 @@ func TestGetRunnerScaleSetMessage(t *testing.T) { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) return nil - }) + }, 10) assert.NoError(t, err, "Error getting message") assert.Equal(t, int64(0), asClient.lastMessageId, "Initial message") @@ -340,7 +340,7 @@ func TestGetRunnerScaleSetMessage(t *testing.T) { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) return nil - }) + }, 10) assert.NoError(t, err, "Error getting message") assert.Equal(t, int64(1), asClient.lastMessageId, "Last message id should be updated") @@ -368,7 +368,7 @@ func TestGetRunnerScaleSetMessage_HandleFailed(t *testing.T) { Statistics: &actions.RunnerScaleSetStatistic{}, } mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) - mockSessionClient.On("GetMessage", ctx, int64(0)).Return(&actions.RunnerScaleSetMessage{ + mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(&actions.RunnerScaleSetMessage{ MessageId: 1, MessageType: "test", Body: "test", @@ -383,14 +383,14 @@ func TestGetRunnerScaleSetMessage_HandleFailed(t *testing.T) { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) return nil - }) + }, 10) assert.NoError(t, err, "Error getting message") err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) return fmt.Errorf("error") - }) + }, 10) assert.ErrorContains(t, err, "handle message failed. error", "Error getting message") assert.Equal(t, int64(0), asClient.lastMessageId, "Last message id should not be updated") @@ -419,7 +419,7 @@ func TestGetRunnerScaleSetMessage_HandleInitialMessage(t *testing.T) { TotalAssignedJobs: 2, }, } - mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) + mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything, mock.Anything).Return(session, nil) mockActionsClient.On("GetAcquirableJobs", ctx, 1).Return(&actions.AcquirableJobList{ Count: 1, Jobs: []actions.AcquirableJob{ @@ -439,7 +439,7 @@ func TestGetRunnerScaleSetMessage_HandleInitialMessage(t *testing.T) { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) return nil - }) + }, 10) assert.NoError(t, err, "Error getting message") assert.Nil(t, asClient.initialMessage, "Initial message should be nil") @@ -488,7 +488,7 @@ func TestGetRunnerScaleSetMessage_HandleInitialMessageFailed(t *testing.T) { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) return fmt.Errorf("error") - }) + }, 10) assert.ErrorContains(t, err, "fail to process initial message. error", "Error getting message") assert.NotNil(t, asClient.initialMessage, "Initial message should be nil") @@ -516,8 +516,8 @@ func TestGetRunnerScaleSetMessage_RetryUntilGetMessage(t *testing.T) { Statistics: &actions.RunnerScaleSetStatistic{}, } mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) - mockSessionClient.On("GetMessage", ctx, int64(0)).Return(nil, nil).Times(3) - mockSessionClient.On("GetMessage", ctx, int64(0)).Return(&actions.RunnerScaleSetMessage{ + mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(nil, nil).Times(3) + mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(&actions.RunnerScaleSetMessage{ MessageId: 1, MessageType: "test", Body: "test", @@ -532,13 +532,13 @@ func TestGetRunnerScaleSetMessage_RetryUntilGetMessage(t *testing.T) { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) return nil - }) + }, 10) assert.NoError(t, err, "Error getting initial message") err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) return nil - }) + }, 10) assert.NoError(t, err, "Error getting message") assert.Equal(t, int64(1), asClient.lastMessageId, "Last message id should be updated") @@ -565,7 +565,7 @@ func TestGetRunnerScaleSetMessage_ErrorOnGetMessage(t *testing.T) { Statistics: &actions.RunnerScaleSetStatistic{}, } mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) - mockSessionClient.On("GetMessage", ctx, int64(0)).Return(nil, fmt.Errorf("error")) + mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(nil, fmt.Errorf("error")) asClient, err := NewAutoScalerClient(ctx, mockActionsClient, &logger, 1, func(asc *AutoScalerClient) { asc.client = mockSessionClient @@ -575,12 +575,12 @@ func TestGetRunnerScaleSetMessage_ErrorOnGetMessage(t *testing.T) { // process initial message err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { return nil - }) + }, 10) assert.NoError(t, err, "Error getting initial message") err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { return fmt.Errorf("Should not be called") - }) + }, 10) assert.ErrorContains(t, err, "get message failed from refreshing client. error", "Error should be returned") assert.Equal(t, int64(0), asClient.lastMessageId, "Last message id should be updated") @@ -608,7 +608,7 @@ func TestDeleteRunnerScaleSetMessage_Error(t *testing.T) { Statistics: &actions.RunnerScaleSetStatistic{}, } mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) - mockSessionClient.On("GetMessage", ctx, int64(0)).Return(&actions.RunnerScaleSetMessage{ + mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(&actions.RunnerScaleSetMessage{ MessageId: 1, MessageType: "test", Body: "test", @@ -623,13 +623,13 @@ func TestDeleteRunnerScaleSetMessage_Error(t *testing.T) { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) return nil - }) + }, 10) assert.NoError(t, err, "Error getting initial message") err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) return nil - }) + }, 10) assert.ErrorContains(t, err, "delete message failed from refreshing client. error", "Error getting message") assert.Equal(t, int64(1), asClient.lastMessageId, "Last message id should be updated") diff --git a/cmd/githubrunnerscalesetlistener/autoScalerService.go b/cmd/githubrunnerscalesetlistener/autoScalerService.go index b8e14521..c3097212 100644 --- a/cmd/githubrunnerscalesetlistener/autoScalerService.go +++ b/cmd/githubrunnerscalesetlistener/autoScalerService.go @@ -89,7 +89,7 @@ func (s *Service) Start() error { s.logger.Info("service is stopped.") return nil default: - err := s.rsClient.GetRunnerScaleSetMessage(s.ctx, s.processMessage) + err := s.rsClient.GetRunnerScaleSetMessage(s.ctx, s.processMessage, s.settings.MaxRunners) if err != nil { return fmt.Errorf("could not get and process message. %w", err) } diff --git a/cmd/githubrunnerscalesetlistener/autoScalerService_test.go b/cmd/githubrunnerscalesetlistener/autoScalerService_test.go index d0e54545..9a353d16 100644 --- a/cmd/githubrunnerscalesetlistener/autoScalerService_test.go +++ b/cmd/githubrunnerscalesetlistener/autoScalerService_test.go @@ -64,7 +64,7 @@ func TestStart(t *testing.T) { ) require.NoError(t, err) - mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything).Run(func(args mock.Arguments) { cancel() }).Return(nil).Once() + mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything, mock.Anything).Run(func(mock.Arguments) { cancel() }).Return(nil).Once() err = service.Start() @@ -98,7 +98,7 @@ func TestStart_ScaleToMinRunners(t *testing.T) { ) require.NoError(t, err) - mockRsClient.On("GetRunnerScaleSetMessage", ctx, mock.Anything).Run(func(args mock.Arguments) { + mockRsClient.On("GetRunnerScaleSetMessage", ctx, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { _ = service.scaleForAssignedJobCount(5) }).Return(nil) @@ -137,7 +137,7 @@ func TestStart_ScaleToMinRunnersFailed(t *testing.T) { require.NoError(t, err) c := mockKubeManager.On("ScaleEphemeralRunnerSet", ctx, service.settings.Namespace, service.settings.ResourceName, 5).Return(fmt.Errorf("error")).Once() - mockRsClient.On("GetRunnerScaleSetMessage", ctx, mock.Anything).Run(func(args mock.Arguments) { + mockRsClient.On("GetRunnerScaleSetMessage", ctx, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { _ = service.scaleForAssignedJobCount(5) }).Return(c.ReturnArguments.Get(0)) @@ -172,8 +172,8 @@ func TestStart_GetMultipleMessages(t *testing.T) { ) require.NoError(t, err) - mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything).Return(nil).Times(5) - mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything).Run(func(args mock.Arguments) { cancel() }).Return(nil).Once() + mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything, mock.Anything).Return(nil).Times(5) + mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { cancel() }).Return(nil).Once() err = service.Start() @@ -207,8 +207,8 @@ func TestStart_ErrorOnMessage(t *testing.T) { ) require.NoError(t, err) - mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything).Return(nil).Times(2) - mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything).Return(fmt.Errorf("error")).Once() + mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything, mock.Anything).Return(nil).Times(2) + mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything, mock.Anything).Return(fmt.Errorf("error")).Once() err = service.Start() diff --git a/cmd/githubrunnerscalesetlistener/messageListener.go b/cmd/githubrunnerscalesetlistener/messageListener.go index 0f01db58..e90aa454 100644 --- a/cmd/githubrunnerscalesetlistener/messageListener.go +++ b/cmd/githubrunnerscalesetlistener/messageListener.go @@ -8,6 +8,6 @@ import ( //go:generate mockery --inpackage --name=RunnerScaleSetClient type RunnerScaleSetClient interface { - GetRunnerScaleSetMessage(ctx context.Context, handler func(msg *actions.RunnerScaleSetMessage) error) error + GetRunnerScaleSetMessage(ctx context.Context, handler func(msg *actions.RunnerScaleSetMessage) error, maxCapacity int) error AcquireJobsForRunnerScaleSet(ctx context.Context, requestIds []int64) error } diff --git a/cmd/githubrunnerscalesetlistener/mock_RunnerScaleSetClient.go b/cmd/githubrunnerscalesetlistener/mock_RunnerScaleSetClient.go index 80ba900a..a6f6a5d1 100644 --- a/cmd/githubrunnerscalesetlistener/mock_RunnerScaleSetClient.go +++ b/cmd/githubrunnerscalesetlistener/mock_RunnerScaleSetClient.go @@ -29,13 +29,13 @@ func (_m *MockRunnerScaleSetClient) AcquireJobsForRunnerScaleSet(ctx context.Con return r0 } -// GetRunnerScaleSetMessage provides a mock function with given fields: ctx, handler -func (_m *MockRunnerScaleSetClient) GetRunnerScaleSetMessage(ctx context.Context, handler func(*actions.RunnerScaleSetMessage) error) error { - ret := _m.Called(ctx, handler) +// GetRunnerScaleSetMessage provides a mock function with given fields: ctx, handler, maxCapacity +func (_m *MockRunnerScaleSetClient) GetRunnerScaleSetMessage(ctx context.Context, handler func(*actions.RunnerScaleSetMessage) error, maxCapacity int) error { + ret := _m.Called(ctx, handler, maxCapacity) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, func(*actions.RunnerScaleSetMessage) error) error); ok { - r0 = rf(ctx, handler) + if rf, ok := ret.Get(0).(func(context.Context, func(*actions.RunnerScaleSetMessage) error, int) error); ok { + r0 = rf(ctx, handler, maxCapacity) } else { r0 = ret.Error(0) } diff --git a/cmd/githubrunnerscalesetlistener/sessionrefreshingclient.go b/cmd/githubrunnerscalesetlistener/sessionrefreshingclient.go index 11df7e21..f3262c15 100644 --- a/cmd/githubrunnerscalesetlistener/sessionrefreshingclient.go +++ b/cmd/githubrunnerscalesetlistener/sessionrefreshingclient.go @@ -24,8 +24,12 @@ func newSessionClient(client actions.ActionsService, logger *logr.Logger, sessio } } -func (m *SessionRefreshingClient) GetMessage(ctx context.Context, lastMessageId int64) (*actions.RunnerScaleSetMessage, error) { - message, err := m.client.GetMessage(ctx, m.session.MessageQueueUrl, m.session.MessageQueueAccessToken, lastMessageId) +func (m *SessionRefreshingClient) GetMessage(ctx context.Context, lastMessageId int64, maxCapacity int) (*actions.RunnerScaleSetMessage, error) { + if maxCapacity < 0 { + return nil, fmt.Errorf("maxCapacity must be greater than or equal to 0") + } + + message, err := m.client.GetMessage(ctx, m.session.MessageQueueUrl, m.session.MessageQueueAccessToken, lastMessageId, maxCapacity) if err == nil { return message, nil } @@ -42,7 +46,7 @@ func (m *SessionRefreshingClient) GetMessage(ctx context.Context, lastMessageId } m.session = session - message, err = m.client.GetMessage(ctx, m.session.MessageQueueUrl, m.session.MessageQueueAccessToken, lastMessageId) + message, err = m.client.GetMessage(ctx, m.session.MessageQueueUrl, m.session.MessageQueueAccessToken, lastMessageId, maxCapacity) if err != nil { return nil, fmt.Errorf("delete message failed after refresh message session. %w", err) } diff --git a/cmd/githubrunnerscalesetlistener/sessionrefreshingclient_test.go b/cmd/githubrunnerscalesetlistener/sessionrefreshingclient_test.go index 1423a0ce..1cdfb6c7 100644 --- a/cmd/githubrunnerscalesetlistener/sessionrefreshingclient_test.go +++ b/cmd/githubrunnerscalesetlistener/sessionrefreshingclient_test.go @@ -31,17 +31,17 @@ func TestGetMessage(t *testing.T) { }, } - mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0)).Return(nil, nil).Once() - mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0)).Return(&actions.RunnerScaleSetMessage{MessageId: 1}, nil).Once() + mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0), 10).Return(nil, nil).Once() + mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0), 10).Return(&actions.RunnerScaleSetMessage{MessageId: 1}, nil).Once() client := newSessionClient(mockActionsClient, &logger, session) - msg, err := client.GetMessage(ctx, 0) + msg, err := client.GetMessage(ctx, 0, 10) require.NoError(t, err, "GetMessage should not return an error") assert.Nil(t, msg, "GetMessage should return nil message") - msg, err = client.GetMessage(ctx, 0) + msg, err = client.GetMessage(ctx, 0, 10) require.NoError(t, err, "GetMessage should not return an error") assert.Equal(t, int64(1), msg.MessageId, "GetMessage should return a message with id 1") @@ -146,11 +146,11 @@ func TestGetMessage_Error(t *testing.T) { }, } - mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0)).Return(nil, fmt.Errorf("error")).Once() + mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0), 10).Return(nil, fmt.Errorf("error")).Once() client := newSessionClient(mockActionsClient, &logger, session) - msg, err := client.GetMessage(ctx, 0) + msg, err := client.GetMessage(ctx, 0, 10) assert.ErrorContains(t, err, "get message failed. error", "GetMessage should return an error") assert.Nil(t, msg, "GetMessage should return nil message") assert.True(t, mockActionsClient.AssertExpectations(t), "All expected calls to mockActionsClient should have been made") @@ -227,8 +227,8 @@ func TestGetMessage_RefreshToken(t *testing.T) { Id: 1, }, } - mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0)).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() - mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, "token2", int64(0)).Return(&actions.RunnerScaleSetMessage{ + mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0), 10).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() + mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, "token2", int64(0), 10).Return(&actions.RunnerScaleSetMessage{ MessageId: 1, MessageType: "test", Body: "test", @@ -243,7 +243,7 @@ func TestGetMessage_RefreshToken(t *testing.T) { }, nil).Once() client := newSessionClient(mockActionsClient, &logger, session) - msg, err := client.GetMessage(ctx, 0) + msg, err := client.GetMessage(ctx, 0, 10) assert.NoError(t, err, "Error getting message") assert.Equal(t, int64(1), msg.MessageId, "message id should be updated") assert.Equal(t, "token2", client.session.MessageQueueAccessToken, "Message queue access token should be updated") @@ -340,11 +340,11 @@ func TestGetMessage_RefreshToken_Failed(t *testing.T) { Id: 1, }, } - mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0)).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() + mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0), 10).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() mockActionsClient.On("RefreshMessageSession", ctx, session.RunnerScaleSet.Id, session.SessionId).Return(nil, fmt.Errorf("error")) client := newSessionClient(mockActionsClient, &logger, session) - msg, err := client.GetMessage(ctx, 0) + msg, err := client.GetMessage(ctx, 0, 10) assert.ErrorContains(t, err, "refresh message session failed. error", "Error should be returned") assert.Nil(t, msg, "Message should be nil") assert.Equal(t, "token", client.session.MessageQueueAccessToken, "Message queue access token should not be updated") diff --git a/github/actions/client.go b/github/actions/client.go index 3470384f..18a078cf 100644 --- a/github/actions/client.go +++ b/github/actions/client.go @@ -29,6 +29,9 @@ const ( apiVersionQueryParam = "api-version=6.0-preview" ) +// Header used to propagate capacity information to the back-end +const HeaderScaleSetMaxCapacity = "X-ScaleSetMaxCapacity" + //go:generate mockery --inpackage --name=ActionsService type ActionsService interface { GetRunnerScaleSet(ctx context.Context, runnerGroupId int, runnerScaleSetName string) (*RunnerScaleSet, error) @@ -45,7 +48,7 @@ type ActionsService interface { AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*AcquirableJobList, error) - GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*RunnerScaleSetMessage, error) + GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*RunnerScaleSetMessage, error) DeleteMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, messageId int64) error GenerateJitRunnerConfig(ctx context.Context, jitRunnerSetting *RunnerScaleSetJitRunnerSetting, scaleSetId int) (*RunnerScaleSetJitRunnerConfig, error) @@ -104,6 +107,8 @@ type Client struct { proxyFunc ProxyFunc } +var _ ActionsService = &Client{} + type ProxyFunc func(req *http.Request) (*url.URL, error) type ClientOption func(*Client) @@ -543,7 +548,7 @@ func (c *Client) DeleteRunnerScaleSet(ctx context.Context, runnerScaleSetId int) return nil } -func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*RunnerScaleSetMessage, error) { +func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*RunnerScaleSetMessage, error) { u, err := url.Parse(messageQueueUrl) if err != nil { return nil, err @@ -555,6 +560,10 @@ func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAc u.RawQuery = q.Encode() } + if maxCapacity < 0 { + return nil, fmt.Errorf("maxCapacity must be greater than or equal to 0") + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { return nil, err @@ -563,6 +572,7 @@ func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAc req.Header.Set("Accept", "application/json; api-version=6.0-preview") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", messageQueueAccessToken)) req.Header.Set("User-Agent", c.userAgent.String()) + req.Header.Set(HeaderScaleSetMaxCapacity, strconv.Itoa(maxCapacity)) resp, err := c.Do(req) if err != nil { diff --git a/github/actions/client_runner_scale_set_message_test.go b/github/actions/client_runner_scale_set_message_test.go index cb67c310..8b15a835 100644 --- a/github/actions/client_runner_scale_set_message_test.go +++ b/github/actions/client_runner_scale_set_message_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "net/http" + "strconv" "testing" "time" @@ -35,7 +36,7 @@ func TestGetMessage(t *testing.T) { client, err := actions.NewClient(s.configURLForOrg("my-org"), auth) require.NoError(t, err) - got, err := client.GetMessage(ctx, s.URL, token, 0) + got, err := client.GetMessage(ctx, s.URL, token, 0, 10) require.NoError(t, err) assert.Equal(t, want, got) }) @@ -52,7 +53,7 @@ func TestGetMessage(t *testing.T) { client, err := actions.NewClient(s.configURLForOrg("my-org"), auth) require.NoError(t, err) - got, err := client.GetMessage(ctx, s.URL, token, 1) + got, err := client.GetMessage(ctx, s.URL, token, 1, 10) require.NoError(t, err) assert.Equal(t, want, got) }) @@ -76,7 +77,7 @@ func TestGetMessage(t *testing.T) { ) require.NoError(t, err) - _, err = client.GetMessage(ctx, server.URL, token, 0) + _, err = client.GetMessage(ctx, server.URL, token, 0, 10) assert.NotNil(t, err) assert.Equalf(t, actualRetry, expectedRetry, "A retry was expected after the first request but got: %v", actualRetry) }) @@ -89,7 +90,7 @@ func TestGetMessage(t *testing.T) { client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) - _, err = client.GetMessage(ctx, server.URL, token, 0) + _, err = client.GetMessage(ctx, server.URL, token, 0, 10) require.NotNil(t, err) var expectedErr *actions.MessageQueueTokenExpiredError @@ -108,7 +109,7 @@ func TestGetMessage(t *testing.T) { client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) - _, err = client.GetMessage(ctx, server.URL, token, 0) + _, err = client.GetMessage(ctx, server.URL, token, 0, 10) require.NotNil(t, err) assert.Equal(t, want.Error(), err.Error()) }) @@ -122,9 +123,35 @@ func TestGetMessage(t *testing.T) { client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) - _, err = client.GetMessage(ctx, server.URL, token, 0) + _, err = client.GetMessage(ctx, server.URL, token, 0, 10) assert.NotNil(t, err) }) + + t.Run("Capacity error handling", func(t *testing.T) { + server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hc := r.Header.Get(actions.HeaderScaleSetMaxCapacity) + c, err := strconv.Atoi(hc) + require.NoError(t, err) + assert.GreaterOrEqual(t, c, 0) + + w.WriteHeader(http.StatusBadRequest) + w.Header().Set("Content-Type", "text/plain") + })) + + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) + require.NoError(t, err) + + _, err = client.GetMessage(ctx, server.URL, token, 0, -1) + require.Error(t, err) + // Ensure we don't send requests with negative capacity + assert.False(t, errors.Is(err, &actions.ActionsError{})) + + _, err = client.GetMessage(ctx, server.URL, token, 0, 0) + assert.Error(t, err) + var expectedErr *actions.ActionsError + assert.ErrorAs(t, err, &expectedErr) + assert.Equal(t, http.StatusBadRequest, expectedErr.StatusCode) + }) } func TestDeleteMessage(t *testing.T) { diff --git a/github/actions/errors.go b/github/actions/errors.go index 736520bd..86ba5b6d 100644 --- a/github/actions/errors.go +++ b/github/actions/errors.go @@ -47,7 +47,6 @@ func (e *ActionsError) IsException(target string) bool { if ex, ok := e.Err.(*ActionsExceptionError); ok { return strings.Contains(ex.ExceptionName, target) } - return false } diff --git a/github/actions/fake/client.go b/github/actions/fake/client.go index de51a278..a108b902 100644 --- a/github/actions/fake/client.go +++ b/github/actions/fake/client.go @@ -259,7 +259,7 @@ func (f *FakeClient) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int return f.getAcquirableJobsResult.AcquirableJobList, f.getAcquirableJobsResult.err } -func (f *FakeClient) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*actions.RunnerScaleSetMessage, error) { +func (f *FakeClient) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*actions.RunnerScaleSetMessage, error) { return f.getMessageResult.RunnerScaleSetMessage, f.getMessageResult.err } diff --git a/github/actions/mock_ActionsService.go b/github/actions/mock_ActionsService.go index 0216cf30..849f2c19 100644 --- a/github/actions/mock_ActionsService.go +++ b/github/actions/mock_ActionsService.go @@ -186,25 +186,25 @@ func (_m *MockActionsService) GetAcquirableJobs(ctx context.Context, runnerScale return r0, r1 } -// GetMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId -func (_m *MockActionsService) GetMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, lastMessageId int64) (*RunnerScaleSetMessage, error) { - ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) +// GetMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity +func (_m *MockActionsService) GetMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*RunnerScaleSetMessage, error) { + ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity) var r0 *RunnerScaleSetMessage var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) (*RunnerScaleSetMessage, error)); ok { - return rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) + if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, int) (*RunnerScaleSetMessage, error)); ok { + return rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *RunnerScaleSetMessage); ok { - r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) + if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, int) *RunnerScaleSetMessage); ok { + r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*RunnerScaleSetMessage) } } - if rf, ok := ret.Get(1).(func(context.Context, string, string, int64) error); ok { - r1 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) + if rf, ok := ret.Get(1).(func(context.Context, string, string, int64, int) error); ok { + r1 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity) } else { r1 = ret.Error(1) } diff --git a/github/actions/mock_SessionService.go b/github/actions/mock_SessionService.go index ed403eee..f587cac8 100644 --- a/github/actions/mock_SessionService.go +++ b/github/actions/mock_SessionService.go @@ -67,25 +67,25 @@ func (_m *MockSessionService) DeleteMessage(ctx context.Context, messageId int64 return r0 } -// GetMessage provides a mock function with given fields: ctx, lastMessageId -func (_m *MockSessionService) GetMessage(ctx context.Context, lastMessageId int64) (*RunnerScaleSetMessage, error) { - ret := _m.Called(ctx, lastMessageId) +// GetMessage provides a mock function with given fields: ctx, lastMessageId, maxCapacity +func (_m *MockSessionService) GetMessage(ctx context.Context, lastMessageId int64, maxCapacity int) (*RunnerScaleSetMessage, error) { + ret := _m.Called(ctx, lastMessageId, maxCapacity) var r0 *RunnerScaleSetMessage var r1 error - if rf, ok := ret.Get(0).(func(context.Context, int64) (*RunnerScaleSetMessage, error)); ok { - return rf(ctx, lastMessageId) + if rf, ok := ret.Get(0).(func(context.Context, int64, int) (*RunnerScaleSetMessage, error)); ok { + return rf(ctx, lastMessageId, maxCapacity) } - if rf, ok := ret.Get(0).(func(context.Context, int64) *RunnerScaleSetMessage); ok { - r0 = rf(ctx, lastMessageId) + if rf, ok := ret.Get(0).(func(context.Context, int64, int) *RunnerScaleSetMessage); ok { + r0 = rf(ctx, lastMessageId, maxCapacity) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*RunnerScaleSetMessage) } } - if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { - r1 = rf(ctx, lastMessageId) + if rf, ok := ret.Get(1).(func(context.Context, int64, int) error); ok { + r1 = rf(ctx, lastMessageId, maxCapacity) } else { r1 = ret.Error(1) } diff --git a/github/actions/sessionservice.go b/github/actions/sessionservice.go index 6ae20fa0..21311aa0 100644 --- a/github/actions/sessionservice.go +++ b/github/actions/sessionservice.go @@ -7,7 +7,7 @@ import ( //go:generate mockery --inpackage --name=SessionService type SessionService interface { - GetMessage(ctx context.Context, lastMessageId int64) (*RunnerScaleSetMessage, error) + GetMessage(ctx context.Context, lastMessageId int64, maxCapacity int) (*RunnerScaleSetMessage, error) DeleteMessage(ctx context.Context, messageId int64) error AcquireJobs(ctx context.Context, requestIds []int64) ([]int64, error) io.Closer