Propagate max capacity information to the actions back-end (#3431)

This commit is contained in:
Nikola Jokic
2024-04-16 14:00:40 +02:00
committed by GitHub
parent 8075e5ee74
commit 4ee49fee14
18 changed files with 147 additions and 100 deletions

View File

@@ -31,7 +31,7 @@ const (
type Client interface { type Client interface {
GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*actions.AcquirableJobList, error) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*actions.AcquirableJobList, error)
CreateMessageSession(ctx context.Context, runnerScaleSetId int, owner string) (*actions.RunnerScaleSetSession, error) CreateMessageSession(ctx context.Context, runnerScaleSetId int, owner string) (*actions.RunnerScaleSetSession, error)
GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*actions.RunnerScaleSetMessage, error) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*actions.RunnerScaleSetMessage, error)
DeleteMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, messageId int64) error DeleteMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, messageId int64) error
AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error)
RefreshMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) (*actions.RunnerScaleSetSession, error) RefreshMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) (*actions.RunnerScaleSetSession, error)
@@ -80,6 +80,7 @@ type Listener struct {
// updated fields // updated fields
lastMessageID int64 // The ID of the last processed message. lastMessageID int64 // The ID of the last processed message.
maxCapacity int // The maximum number of runners that can be created.
session *actions.RunnerScaleSetSession // The session for managing the runner scale set. session *actions.RunnerScaleSetSession // The session for managing the runner scale set.
} }
@@ -89,10 +90,11 @@ func New(config Config) (*Listener, error) {
} }
listener := &Listener{ listener := &Listener{
scaleSetID: config.ScaleSetID, scaleSetID: config.ScaleSetID,
client: config.Client, client: config.Client,
logger: config.Logger, logger: config.Logger,
metrics: metrics.Discard, metrics: metrics.Discard,
maxCapacity: config.MaxRunners,
} }
if config.Metrics != nil { if config.Metrics != nil {
@@ -267,7 +269,7 @@ func (l *Listener) createSession(ctx context.Context) error {
func (l *Listener) getMessage(ctx context.Context) (*actions.RunnerScaleSetMessage, error) { func (l *Listener) getMessage(ctx context.Context) (*actions.RunnerScaleSetMessage, error) {
l.logger.Info("Getting next message", "lastMessageID", l.lastMessageID) l.logger.Info("Getting next message", "lastMessageID", l.lastMessageID)
msg, err := l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID) msg, err := l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID, l.maxCapacity)
if err == nil { // if NO error if err == nil { // if NO error
return msg, nil return msg, nil
} }
@@ -283,7 +285,7 @@ func (l *Listener) getMessage(ctx context.Context) (*actions.RunnerScaleSetMessa
l.logger.Info("Getting next message", "lastMessageID", l.lastMessageID) l.logger.Info("Getting next message", "lastMessageID", l.lastMessageID)
msg, err = l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID) msg, err = l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID, l.maxCapacity)
if err != nil { // if NO error if err != nil { // if NO error
return nil, fmt.Errorf("failed to get next message after message session refresh: %w", err) return nil, fmt.Errorf("failed to get next message after message session refresh: %w", err)
} }

View File

@@ -123,13 +123,14 @@ func TestListener_getMessage(t *testing.T) {
config := Config{ config := Config{
ScaleSetID: 1, ScaleSetID: 1,
Metrics: metrics.Discard, Metrics: metrics.Discard,
MaxRunners: 10,
} }
client := listenermocks.NewClient(t) client := listenermocks.NewClient(t)
want := &actions.RunnerScaleSetMessage{ want := &actions.RunnerScaleSetMessage{
MessageId: 1, MessageId: 1,
} }
client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(want, nil).Once() client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(want, nil).Once()
config.Client = client config.Client = client
l, err := New(config) l, err := New(config)
@@ -148,10 +149,11 @@ func TestListener_getMessage(t *testing.T) {
config := Config{ config := Config{
ScaleSetID: 1, ScaleSetID: 1,
Metrics: metrics.Discard, Metrics: metrics.Discard,
MaxRunners: 10,
} }
client := listenermocks.NewClient(t) client := listenermocks.NewClient(t)
client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.HttpClientSideError{Code: http.StatusNotFound}).Once() client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(nil, &actions.HttpClientSideError{Code: http.StatusNotFound}).Once()
config.Client = client config.Client = client
l, err := New(config) l, err := New(config)
@@ -170,6 +172,7 @@ func TestListener_getMessage(t *testing.T) {
config := Config{ config := Config{
ScaleSetID: 1, ScaleSetID: 1,
Metrics: metrics.Discard, Metrics: metrics.Discard,
MaxRunners: 10,
} }
client := listenermocks.NewClient(t) client := listenermocks.NewClient(t)
@@ -185,12 +188,12 @@ func TestListener_getMessage(t *testing.T) {
} }
client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once()
client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once()
want := &actions.RunnerScaleSetMessage{ want := &actions.RunnerScaleSetMessage{
MessageId: 1, MessageId: 1,
} }
client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(want, nil).Once() client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(want, nil).Once()
config.Client = client config.Client = client
@@ -214,6 +217,7 @@ func TestListener_getMessage(t *testing.T) {
config := Config{ config := Config{
ScaleSetID: 1, ScaleSetID: 1,
Metrics: metrics.Discard, Metrics: metrics.Discard,
MaxRunners: 10,
} }
client := listenermocks.NewClient(t) client := listenermocks.NewClient(t)
@@ -229,7 +233,7 @@ func TestListener_getMessage(t *testing.T) {
} }
client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once()
client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.MessageQueueTokenExpiredError{}).Twice() client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(nil, &actions.MessageQueueTokenExpiredError{}).Twice()
config.Client = client config.Client = client
@@ -450,6 +454,7 @@ func TestListener_Listen(t *testing.T) {
config := Config{ config := Config{
ScaleSetID: 1, ScaleSetID: 1,
Metrics: metrics.Discard, Metrics: metrics.Discard,
MaxRunners: 10,
} }
client := listenermocks.NewClient(t) client := listenermocks.NewClient(t)
@@ -470,7 +475,7 @@ func TestListener_Listen(t *testing.T) {
MessageType: "RunnerScaleSetJobMessages", MessageType: "RunnerScaleSetJobMessages",
Statistics: &actions.RunnerScaleSetStatistic{}, Statistics: &actions.RunnerScaleSetStatistic{},
} }
client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything). client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).
Return(msg, nil). Return(msg, nil).
Run( Run(
func(mock.Arguments) { func(mock.Arguments) {

View File

@@ -123,25 +123,25 @@ func (_m *Client) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (
return r0, r1 return r0, r1
} }
// GetMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId // GetMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity
func (_m *Client) GetMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, lastMessageId int64) (*actions.RunnerScaleSetMessage, error) { func (_m *Client) GetMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*actions.RunnerScaleSetMessage, error) {
ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
var r0 *actions.RunnerScaleSetMessage var r0 *actions.RunnerScaleSetMessage
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) (*actions.RunnerScaleSetMessage, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, int) (*actions.RunnerScaleSetMessage, error)); ok {
return rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) return rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
} }
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *actions.RunnerScaleSetMessage); ok { if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, int) *actions.RunnerScaleSetMessage); ok {
r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*actions.RunnerScaleSetMessage) r0 = ret.Get(0).(*actions.RunnerScaleSetMessage)
} }
} }
if rf, ok := ret.Get(1).(func(context.Context, string, string, int64) error); ok { if rf, ok := ret.Get(1).(func(context.Context, string, string, int64, int) error); ok {
r1 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) r1 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }

View File

@@ -129,7 +129,7 @@ func (m *AutoScalerClient) Close() error {
return m.client.Close() return m.client.Close()
} }
func (m *AutoScalerClient) GetRunnerScaleSetMessage(ctx context.Context, handler func(msg *actions.RunnerScaleSetMessage) error) error { func (m *AutoScalerClient) GetRunnerScaleSetMessage(ctx context.Context, handler func(msg *actions.RunnerScaleSetMessage) error, maxCapacity int) error {
if m.initialMessage != nil { if m.initialMessage != nil {
err := handler(m.initialMessage) err := handler(m.initialMessage)
if err != nil { if err != nil {
@@ -141,7 +141,7 @@ func (m *AutoScalerClient) GetRunnerScaleSetMessage(ctx context.Context, handler
} }
for { for {
message, err := m.client.GetMessage(ctx, m.lastMessageId) message, err := m.client.GetMessage(ctx, m.lastMessageId, maxCapacity)
if err != nil { if err != nil {
return fmt.Errorf("get message failed from refreshing client. %w", err) return fmt.Errorf("get message failed from refreshing client. %w", err)
} }

View File

@@ -317,7 +317,7 @@ func TestGetRunnerScaleSetMessage(t *testing.T) {
Statistics: &actions.RunnerScaleSetStatistic{}, Statistics: &actions.RunnerScaleSetStatistic{},
} }
mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil)
mockSessionClient.On("GetMessage", ctx, int64(0)).Return(&actions.RunnerScaleSetMessage{ mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(&actions.RunnerScaleSetMessage{
MessageId: 1, MessageId: 1,
MessageType: "test", MessageType: "test",
Body: "test", Body: "test",
@@ -332,7 +332,7 @@ func TestGetRunnerScaleSetMessage(t *testing.T) {
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body)
return nil return nil
}) }, 10)
assert.NoError(t, err, "Error getting message") assert.NoError(t, err, "Error getting message")
assert.Equal(t, int64(0), asClient.lastMessageId, "Initial message") assert.Equal(t, int64(0), asClient.lastMessageId, "Initial message")
@@ -340,7 +340,7 @@ func TestGetRunnerScaleSetMessage(t *testing.T) {
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body)
return nil return nil
}) }, 10)
assert.NoError(t, err, "Error getting message") assert.NoError(t, err, "Error getting message")
assert.Equal(t, int64(1), asClient.lastMessageId, "Last message id should be updated") assert.Equal(t, int64(1), asClient.lastMessageId, "Last message id should be updated")
@@ -368,7 +368,7 @@ func TestGetRunnerScaleSetMessage_HandleFailed(t *testing.T) {
Statistics: &actions.RunnerScaleSetStatistic{}, Statistics: &actions.RunnerScaleSetStatistic{},
} }
mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil)
mockSessionClient.On("GetMessage", ctx, int64(0)).Return(&actions.RunnerScaleSetMessage{ mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(&actions.RunnerScaleSetMessage{
MessageId: 1, MessageId: 1,
MessageType: "test", MessageType: "test",
Body: "test", Body: "test",
@@ -383,14 +383,14 @@ func TestGetRunnerScaleSetMessage_HandleFailed(t *testing.T) {
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body)
return nil return nil
}) }, 10)
assert.NoError(t, err, "Error getting message") assert.NoError(t, err, "Error getting message")
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body)
return fmt.Errorf("error") return fmt.Errorf("error")
}) }, 10)
assert.ErrorContains(t, err, "handle message failed. error", "Error getting message") assert.ErrorContains(t, err, "handle message failed. error", "Error getting message")
assert.Equal(t, int64(0), asClient.lastMessageId, "Last message id should not be updated") assert.Equal(t, int64(0), asClient.lastMessageId, "Last message id should not be updated")
@@ -419,7 +419,7 @@ func TestGetRunnerScaleSetMessage_HandleInitialMessage(t *testing.T) {
TotalAssignedJobs: 2, TotalAssignedJobs: 2,
}, },
} }
mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything, mock.Anything).Return(session, nil)
mockActionsClient.On("GetAcquirableJobs", ctx, 1).Return(&actions.AcquirableJobList{ mockActionsClient.On("GetAcquirableJobs", ctx, 1).Return(&actions.AcquirableJobList{
Count: 1, Count: 1,
Jobs: []actions.AcquirableJob{ Jobs: []actions.AcquirableJob{
@@ -439,7 +439,7 @@ func TestGetRunnerScaleSetMessage_HandleInitialMessage(t *testing.T) {
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body)
return nil return nil
}) }, 10)
assert.NoError(t, err, "Error getting message") assert.NoError(t, err, "Error getting message")
assert.Nil(t, asClient.initialMessage, "Initial message should be nil") assert.Nil(t, asClient.initialMessage, "Initial message should be nil")
@@ -488,7 +488,7 @@ func TestGetRunnerScaleSetMessage_HandleInitialMessageFailed(t *testing.T) {
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body)
return fmt.Errorf("error") return fmt.Errorf("error")
}) }, 10)
assert.ErrorContains(t, err, "fail to process initial message. error", "Error getting message") assert.ErrorContains(t, err, "fail to process initial message. error", "Error getting message")
assert.NotNil(t, asClient.initialMessage, "Initial message should be nil") assert.NotNil(t, asClient.initialMessage, "Initial message should be nil")
@@ -516,8 +516,8 @@ func TestGetRunnerScaleSetMessage_RetryUntilGetMessage(t *testing.T) {
Statistics: &actions.RunnerScaleSetStatistic{}, Statistics: &actions.RunnerScaleSetStatistic{},
} }
mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil)
mockSessionClient.On("GetMessage", ctx, int64(0)).Return(nil, nil).Times(3) mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(nil, nil).Times(3)
mockSessionClient.On("GetMessage", ctx, int64(0)).Return(&actions.RunnerScaleSetMessage{ mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(&actions.RunnerScaleSetMessage{
MessageId: 1, MessageId: 1,
MessageType: "test", MessageType: "test",
Body: "test", Body: "test",
@@ -532,13 +532,13 @@ func TestGetRunnerScaleSetMessage_RetryUntilGetMessage(t *testing.T) {
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body)
return nil return nil
}) }, 10)
assert.NoError(t, err, "Error getting initial message") assert.NoError(t, err, "Error getting initial message")
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body)
return nil return nil
}) }, 10)
assert.NoError(t, err, "Error getting message") assert.NoError(t, err, "Error getting message")
assert.Equal(t, int64(1), asClient.lastMessageId, "Last message id should be updated") assert.Equal(t, int64(1), asClient.lastMessageId, "Last message id should be updated")
@@ -565,7 +565,7 @@ func TestGetRunnerScaleSetMessage_ErrorOnGetMessage(t *testing.T) {
Statistics: &actions.RunnerScaleSetStatistic{}, Statistics: &actions.RunnerScaleSetStatistic{},
} }
mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil)
mockSessionClient.On("GetMessage", ctx, int64(0)).Return(nil, fmt.Errorf("error")) mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(nil, fmt.Errorf("error"))
asClient, err := NewAutoScalerClient(ctx, mockActionsClient, &logger, 1, func(asc *AutoScalerClient) { asClient, err := NewAutoScalerClient(ctx, mockActionsClient, &logger, 1, func(asc *AutoScalerClient) {
asc.client = mockSessionClient asc.client = mockSessionClient
@@ -575,12 +575,12 @@ func TestGetRunnerScaleSetMessage_ErrorOnGetMessage(t *testing.T) {
// process initial message // process initial message
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
return nil return nil
}) }, 10)
assert.NoError(t, err, "Error getting initial message") assert.NoError(t, err, "Error getting initial message")
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
return fmt.Errorf("Should not be called") return fmt.Errorf("Should not be called")
}) }, 10)
assert.ErrorContains(t, err, "get message failed from refreshing client. error", "Error should be returned") assert.ErrorContains(t, err, "get message failed from refreshing client. error", "Error should be returned")
assert.Equal(t, int64(0), asClient.lastMessageId, "Last message id should be updated") assert.Equal(t, int64(0), asClient.lastMessageId, "Last message id should be updated")
@@ -608,7 +608,7 @@ func TestDeleteRunnerScaleSetMessage_Error(t *testing.T) {
Statistics: &actions.RunnerScaleSetStatistic{}, Statistics: &actions.RunnerScaleSetStatistic{},
} }
mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil) mockActionsClient.On("CreateMessageSession", ctx, 1, mock.Anything).Return(session, nil)
mockSessionClient.On("GetMessage", ctx, int64(0)).Return(&actions.RunnerScaleSetMessage{ mockSessionClient.On("GetMessage", ctx, int64(0), mock.Anything).Return(&actions.RunnerScaleSetMessage{
MessageId: 1, MessageId: 1,
MessageType: "test", MessageType: "test",
Body: "test", Body: "test",
@@ -623,13 +623,13 @@ func TestDeleteRunnerScaleSetMessage_Error(t *testing.T) {
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body)
return nil return nil
}) }, 10)
assert.NoError(t, err, "Error getting initial message") assert.NoError(t, err, "Error getting initial message")
err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error { err = asClient.GetRunnerScaleSetMessage(ctx, func(msg *actions.RunnerScaleSetMessage) error {
logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body) logger.Info("Message received", "messageId", msg.MessageId, "messageType", msg.MessageType, "body", msg.Body)
return nil return nil
}) }, 10)
assert.ErrorContains(t, err, "delete message failed from refreshing client. error", "Error getting message") assert.ErrorContains(t, err, "delete message failed from refreshing client. error", "Error getting message")
assert.Equal(t, int64(1), asClient.lastMessageId, "Last message id should be updated") assert.Equal(t, int64(1), asClient.lastMessageId, "Last message id should be updated")

View File

@@ -89,7 +89,7 @@ func (s *Service) Start() error {
s.logger.Info("service is stopped.") s.logger.Info("service is stopped.")
return nil return nil
default: default:
err := s.rsClient.GetRunnerScaleSetMessage(s.ctx, s.processMessage) err := s.rsClient.GetRunnerScaleSetMessage(s.ctx, s.processMessage, s.settings.MaxRunners)
if err != nil { if err != nil {
return fmt.Errorf("could not get and process message. %w", err) return fmt.Errorf("could not get and process message. %w", err)
} }

View File

@@ -64,7 +64,7 @@ func TestStart(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything).Run(func(args mock.Arguments) { cancel() }).Return(nil).Once() mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything, mock.Anything).Run(func(mock.Arguments) { cancel() }).Return(nil).Once()
err = service.Start() err = service.Start()
@@ -98,7 +98,7 @@ func TestStart_ScaleToMinRunners(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
mockRsClient.On("GetRunnerScaleSetMessage", ctx, mock.Anything).Run(func(args mock.Arguments) { mockRsClient.On("GetRunnerScaleSetMessage", ctx, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
_ = service.scaleForAssignedJobCount(5) _ = service.scaleForAssignedJobCount(5)
}).Return(nil) }).Return(nil)
@@ -137,7 +137,7 @@ func TestStart_ScaleToMinRunnersFailed(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
c := mockKubeManager.On("ScaleEphemeralRunnerSet", ctx, service.settings.Namespace, service.settings.ResourceName, 5).Return(fmt.Errorf("error")).Once() c := mockKubeManager.On("ScaleEphemeralRunnerSet", ctx, service.settings.Namespace, service.settings.ResourceName, 5).Return(fmt.Errorf("error")).Once()
mockRsClient.On("GetRunnerScaleSetMessage", ctx, mock.Anything).Run(func(args mock.Arguments) { mockRsClient.On("GetRunnerScaleSetMessage", ctx, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
_ = service.scaleForAssignedJobCount(5) _ = service.scaleForAssignedJobCount(5)
}).Return(c.ReturnArguments.Get(0)) }).Return(c.ReturnArguments.Get(0))
@@ -172,8 +172,8 @@ func TestStart_GetMultipleMessages(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything).Return(nil).Times(5) mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything, mock.Anything).Return(nil).Times(5)
mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything).Run(func(args mock.Arguments) { cancel() }).Return(nil).Once() mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { cancel() }).Return(nil).Once()
err = service.Start() err = service.Start()
@@ -207,8 +207,8 @@ func TestStart_ErrorOnMessage(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything).Return(nil).Times(2) mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything, mock.Anything).Return(nil).Times(2)
mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything).Return(fmt.Errorf("error")).Once() mockRsClient.On("GetRunnerScaleSetMessage", service.ctx, mock.Anything, mock.Anything).Return(fmt.Errorf("error")).Once()
err = service.Start() err = service.Start()

View File

@@ -8,6 +8,6 @@ import (
//go:generate mockery --inpackage --name=RunnerScaleSetClient //go:generate mockery --inpackage --name=RunnerScaleSetClient
type RunnerScaleSetClient interface { type RunnerScaleSetClient interface {
GetRunnerScaleSetMessage(ctx context.Context, handler func(msg *actions.RunnerScaleSetMessage) error) error GetRunnerScaleSetMessage(ctx context.Context, handler func(msg *actions.RunnerScaleSetMessage) error, maxCapacity int) error
AcquireJobsForRunnerScaleSet(ctx context.Context, requestIds []int64) error AcquireJobsForRunnerScaleSet(ctx context.Context, requestIds []int64) error
} }

View File

@@ -29,13 +29,13 @@ func (_m *MockRunnerScaleSetClient) AcquireJobsForRunnerScaleSet(ctx context.Con
return r0 return r0
} }
// GetRunnerScaleSetMessage provides a mock function with given fields: ctx, handler // GetRunnerScaleSetMessage provides a mock function with given fields: ctx, handler, maxCapacity
func (_m *MockRunnerScaleSetClient) GetRunnerScaleSetMessage(ctx context.Context, handler func(*actions.RunnerScaleSetMessage) error) error { func (_m *MockRunnerScaleSetClient) GetRunnerScaleSetMessage(ctx context.Context, handler func(*actions.RunnerScaleSetMessage) error, maxCapacity int) error {
ret := _m.Called(ctx, handler) ret := _m.Called(ctx, handler, maxCapacity)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, func(*actions.RunnerScaleSetMessage) error) error); ok { if rf, ok := ret.Get(0).(func(context.Context, func(*actions.RunnerScaleSetMessage) error, int) error); ok {
r0 = rf(ctx, handler) r0 = rf(ctx, handler, maxCapacity)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }

View File

@@ -24,8 +24,12 @@ func newSessionClient(client actions.ActionsService, logger *logr.Logger, sessio
} }
} }
func (m *SessionRefreshingClient) GetMessage(ctx context.Context, lastMessageId int64) (*actions.RunnerScaleSetMessage, error) { func (m *SessionRefreshingClient) GetMessage(ctx context.Context, lastMessageId int64, maxCapacity int) (*actions.RunnerScaleSetMessage, error) {
message, err := m.client.GetMessage(ctx, m.session.MessageQueueUrl, m.session.MessageQueueAccessToken, lastMessageId) if maxCapacity < 0 {
return nil, fmt.Errorf("maxCapacity must be greater than or equal to 0")
}
message, err := m.client.GetMessage(ctx, m.session.MessageQueueUrl, m.session.MessageQueueAccessToken, lastMessageId, maxCapacity)
if err == nil { if err == nil {
return message, nil return message, nil
} }
@@ -42,7 +46,7 @@ func (m *SessionRefreshingClient) GetMessage(ctx context.Context, lastMessageId
} }
m.session = session m.session = session
message, err = m.client.GetMessage(ctx, m.session.MessageQueueUrl, m.session.MessageQueueAccessToken, lastMessageId) message, err = m.client.GetMessage(ctx, m.session.MessageQueueUrl, m.session.MessageQueueAccessToken, lastMessageId, maxCapacity)
if err != nil { if err != nil {
return nil, fmt.Errorf("delete message failed after refresh message session. %w", err) return nil, fmt.Errorf("delete message failed after refresh message session. %w", err)
} }

View File

@@ -31,17 +31,17 @@ func TestGetMessage(t *testing.T) {
}, },
} }
mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0)).Return(nil, nil).Once() mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0), 10).Return(nil, nil).Once()
mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0)).Return(&actions.RunnerScaleSetMessage{MessageId: 1}, nil).Once() mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0), 10).Return(&actions.RunnerScaleSetMessage{MessageId: 1}, nil).Once()
client := newSessionClient(mockActionsClient, &logger, session) client := newSessionClient(mockActionsClient, &logger, session)
msg, err := client.GetMessage(ctx, 0) msg, err := client.GetMessage(ctx, 0, 10)
require.NoError(t, err, "GetMessage should not return an error") require.NoError(t, err, "GetMessage should not return an error")
assert.Nil(t, msg, "GetMessage should return nil message") assert.Nil(t, msg, "GetMessage should return nil message")
msg, err = client.GetMessage(ctx, 0) msg, err = client.GetMessage(ctx, 0, 10)
require.NoError(t, err, "GetMessage should not return an error") require.NoError(t, err, "GetMessage should not return an error")
assert.Equal(t, int64(1), msg.MessageId, "GetMessage should return a message with id 1") assert.Equal(t, int64(1), msg.MessageId, "GetMessage should return a message with id 1")
@@ -146,11 +146,11 @@ func TestGetMessage_Error(t *testing.T) {
}, },
} }
mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0)).Return(nil, fmt.Errorf("error")).Once() mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0), 10).Return(nil, fmt.Errorf("error")).Once()
client := newSessionClient(mockActionsClient, &logger, session) client := newSessionClient(mockActionsClient, &logger, session)
msg, err := client.GetMessage(ctx, 0) msg, err := client.GetMessage(ctx, 0, 10)
assert.ErrorContains(t, err, "get message failed. error", "GetMessage should return an error") assert.ErrorContains(t, err, "get message failed. error", "GetMessage should return an error")
assert.Nil(t, msg, "GetMessage should return nil message") assert.Nil(t, msg, "GetMessage should return nil message")
assert.True(t, mockActionsClient.AssertExpectations(t), "All expected calls to mockActionsClient should have been made") assert.True(t, mockActionsClient.AssertExpectations(t), "All expected calls to mockActionsClient should have been made")
@@ -227,8 +227,8 @@ func TestGetMessage_RefreshToken(t *testing.T) {
Id: 1, Id: 1,
}, },
} }
mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0)).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0), 10).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once()
mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, "token2", int64(0)).Return(&actions.RunnerScaleSetMessage{ mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, "token2", int64(0), 10).Return(&actions.RunnerScaleSetMessage{
MessageId: 1, MessageId: 1,
MessageType: "test", MessageType: "test",
Body: "test", Body: "test",
@@ -243,7 +243,7 @@ func TestGetMessage_RefreshToken(t *testing.T) {
}, nil).Once() }, nil).Once()
client := newSessionClient(mockActionsClient, &logger, session) client := newSessionClient(mockActionsClient, &logger, session)
msg, err := client.GetMessage(ctx, 0) msg, err := client.GetMessage(ctx, 0, 10)
assert.NoError(t, err, "Error getting message") assert.NoError(t, err, "Error getting message")
assert.Equal(t, int64(1), msg.MessageId, "message id should be updated") assert.Equal(t, int64(1), msg.MessageId, "message id should be updated")
assert.Equal(t, "token2", client.session.MessageQueueAccessToken, "Message queue access token should be updated") assert.Equal(t, "token2", client.session.MessageQueueAccessToken, "Message queue access token should be updated")
@@ -340,11 +340,11 @@ func TestGetMessage_RefreshToken_Failed(t *testing.T) {
Id: 1, Id: 1,
}, },
} }
mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0)).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() mockActionsClient.On("GetMessage", ctx, session.MessageQueueUrl, session.MessageQueueAccessToken, int64(0), 10).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once()
mockActionsClient.On("RefreshMessageSession", ctx, session.RunnerScaleSet.Id, session.SessionId).Return(nil, fmt.Errorf("error")) mockActionsClient.On("RefreshMessageSession", ctx, session.RunnerScaleSet.Id, session.SessionId).Return(nil, fmt.Errorf("error"))
client := newSessionClient(mockActionsClient, &logger, session) client := newSessionClient(mockActionsClient, &logger, session)
msg, err := client.GetMessage(ctx, 0) msg, err := client.GetMessage(ctx, 0, 10)
assert.ErrorContains(t, err, "refresh message session failed. error", "Error should be returned") assert.ErrorContains(t, err, "refresh message session failed. error", "Error should be returned")
assert.Nil(t, msg, "Message should be nil") assert.Nil(t, msg, "Message should be nil")
assert.Equal(t, "token", client.session.MessageQueueAccessToken, "Message queue access token should not be updated") assert.Equal(t, "token", client.session.MessageQueueAccessToken, "Message queue access token should not be updated")

View File

@@ -29,6 +29,9 @@ const (
apiVersionQueryParam = "api-version=6.0-preview" apiVersionQueryParam = "api-version=6.0-preview"
) )
// Header used to propagate capacity information to the back-end
const HeaderScaleSetMaxCapacity = "X-ScaleSetMaxCapacity"
//go:generate mockery --inpackage --name=ActionsService //go:generate mockery --inpackage --name=ActionsService
type ActionsService interface { type ActionsService interface {
GetRunnerScaleSet(ctx context.Context, runnerGroupId int, runnerScaleSetName string) (*RunnerScaleSet, error) GetRunnerScaleSet(ctx context.Context, runnerGroupId int, runnerScaleSetName string) (*RunnerScaleSet, error)
@@ -45,7 +48,7 @@ type ActionsService interface {
AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error)
GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*AcquirableJobList, error) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*AcquirableJobList, error)
GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*RunnerScaleSetMessage, error) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*RunnerScaleSetMessage, error)
DeleteMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, messageId int64) error DeleteMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, messageId int64) error
GenerateJitRunnerConfig(ctx context.Context, jitRunnerSetting *RunnerScaleSetJitRunnerSetting, scaleSetId int) (*RunnerScaleSetJitRunnerConfig, error) GenerateJitRunnerConfig(ctx context.Context, jitRunnerSetting *RunnerScaleSetJitRunnerSetting, scaleSetId int) (*RunnerScaleSetJitRunnerConfig, error)
@@ -104,6 +107,8 @@ type Client struct {
proxyFunc ProxyFunc proxyFunc ProxyFunc
} }
var _ ActionsService = &Client{}
type ProxyFunc func(req *http.Request) (*url.URL, error) type ProxyFunc func(req *http.Request) (*url.URL, error)
type ClientOption func(*Client) type ClientOption func(*Client)
@@ -543,7 +548,7 @@ func (c *Client) DeleteRunnerScaleSet(ctx context.Context, runnerScaleSetId int)
return nil return nil
} }
func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*RunnerScaleSetMessage, error) { func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*RunnerScaleSetMessage, error) {
u, err := url.Parse(messageQueueUrl) u, err := url.Parse(messageQueueUrl)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -555,6 +560,10 @@ func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAc
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
} }
if maxCapacity < 0 {
return nil, fmt.Errorf("maxCapacity must be greater than or equal to 0")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -563,6 +572,7 @@ func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAc
req.Header.Set("Accept", "application/json; api-version=6.0-preview") req.Header.Set("Accept", "application/json; api-version=6.0-preview")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", messageQueueAccessToken)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", messageQueueAccessToken))
req.Header.Set("User-Agent", c.userAgent.String()) req.Header.Set("User-Agent", c.userAgent.String())
req.Header.Set(HeaderScaleSetMaxCapacity, strconv.Itoa(maxCapacity))
resp, err := c.Do(req) resp, err := c.Do(req)
if err != nil { if err != nil {

View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"net/http" "net/http"
"strconv"
"testing" "testing"
"time" "time"
@@ -35,7 +36,7 @@ func TestGetMessage(t *testing.T) {
client, err := actions.NewClient(s.configURLForOrg("my-org"), auth) client, err := actions.NewClient(s.configURLForOrg("my-org"), auth)
require.NoError(t, err) require.NoError(t, err)
got, err := client.GetMessage(ctx, s.URL, token, 0) got, err := client.GetMessage(ctx, s.URL, token, 0, 10)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, want, got) assert.Equal(t, want, got)
}) })
@@ -52,7 +53,7 @@ func TestGetMessage(t *testing.T) {
client, err := actions.NewClient(s.configURLForOrg("my-org"), auth) client, err := actions.NewClient(s.configURLForOrg("my-org"), auth)
require.NoError(t, err) require.NoError(t, err)
got, err := client.GetMessage(ctx, s.URL, token, 1) got, err := client.GetMessage(ctx, s.URL, token, 1, 10)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, want, got) assert.Equal(t, want, got)
}) })
@@ -76,7 +77,7 @@ func TestGetMessage(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
_, err = client.GetMessage(ctx, server.URL, token, 0) _, err = client.GetMessage(ctx, server.URL, token, 0, 10)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equalf(t, actualRetry, expectedRetry, "A retry was expected after the first request but got: %v", actualRetry) assert.Equalf(t, actualRetry, expectedRetry, "A retry was expected after the first request but got: %v", actualRetry)
}) })
@@ -89,7 +90,7 @@ func TestGetMessage(t *testing.T) {
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err) require.NoError(t, err)
_, err = client.GetMessage(ctx, server.URL, token, 0) _, err = client.GetMessage(ctx, server.URL, token, 0, 10)
require.NotNil(t, err) require.NotNil(t, err)
var expectedErr *actions.MessageQueueTokenExpiredError var expectedErr *actions.MessageQueueTokenExpiredError
@@ -108,7 +109,7 @@ func TestGetMessage(t *testing.T) {
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err) require.NoError(t, err)
_, err = client.GetMessage(ctx, server.URL, token, 0) _, err = client.GetMessage(ctx, server.URL, token, 0, 10)
require.NotNil(t, err) require.NotNil(t, err)
assert.Equal(t, want.Error(), err.Error()) assert.Equal(t, want.Error(), err.Error())
}) })
@@ -122,9 +123,35 @@ func TestGetMessage(t *testing.T) {
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err) require.NoError(t, err)
_, err = client.GetMessage(ctx, server.URL, token, 0) _, err = client.GetMessage(ctx, server.URL, token, 0, 10)
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
t.Run("Capacity error handling", func(t *testing.T) {
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hc := r.Header.Get(actions.HeaderScaleSetMaxCapacity)
c, err := strconv.Atoi(hc)
require.NoError(t, err)
assert.GreaterOrEqual(t, c, 0)
w.WriteHeader(http.StatusBadRequest)
w.Header().Set("Content-Type", "text/plain")
}))
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetMessage(ctx, server.URL, token, 0, -1)
require.Error(t, err)
// Ensure we don't send requests with negative capacity
assert.False(t, errors.Is(err, &actions.ActionsError{}))
_, err = client.GetMessage(ctx, server.URL, token, 0, 0)
assert.Error(t, err)
var expectedErr *actions.ActionsError
assert.ErrorAs(t, err, &expectedErr)
assert.Equal(t, http.StatusBadRequest, expectedErr.StatusCode)
})
} }
func TestDeleteMessage(t *testing.T) { func TestDeleteMessage(t *testing.T) {

View File

@@ -47,7 +47,6 @@ func (e *ActionsError) IsException(target string) bool {
if ex, ok := e.Err.(*ActionsExceptionError); ok { if ex, ok := e.Err.(*ActionsExceptionError); ok {
return strings.Contains(ex.ExceptionName, target) return strings.Contains(ex.ExceptionName, target)
} }
return false return false
} }

View File

@@ -259,7 +259,7 @@ func (f *FakeClient) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int
return f.getAcquirableJobsResult.AcquirableJobList, f.getAcquirableJobsResult.err return f.getAcquirableJobsResult.AcquirableJobList, f.getAcquirableJobsResult.err
} }
func (f *FakeClient) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*actions.RunnerScaleSetMessage, error) { func (f *FakeClient) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*actions.RunnerScaleSetMessage, error) {
return f.getMessageResult.RunnerScaleSetMessage, f.getMessageResult.err return f.getMessageResult.RunnerScaleSetMessage, f.getMessageResult.err
} }

View File

@@ -186,25 +186,25 @@ func (_m *MockActionsService) GetAcquirableJobs(ctx context.Context, runnerScale
return r0, r1 return r0, r1
} }
// GetMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId // GetMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity
func (_m *MockActionsService) GetMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, lastMessageId int64) (*RunnerScaleSetMessage, error) { func (_m *MockActionsService) GetMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*RunnerScaleSetMessage, error) {
ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
var r0 *RunnerScaleSetMessage var r0 *RunnerScaleSetMessage
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) (*RunnerScaleSetMessage, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, int) (*RunnerScaleSetMessage, error)); ok {
return rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) return rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
} }
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *RunnerScaleSetMessage); ok { if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, int) *RunnerScaleSetMessage); ok {
r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*RunnerScaleSetMessage) r0 = ret.Get(0).(*RunnerScaleSetMessage)
} }
} }
if rf, ok := ret.Get(1).(func(context.Context, string, string, int64) error); ok { if rf, ok := ret.Get(1).(func(context.Context, string, string, int64, int) error); ok {
r1 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) r1 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }

View File

@@ -67,25 +67,25 @@ func (_m *MockSessionService) DeleteMessage(ctx context.Context, messageId int64
return r0 return r0
} }
// GetMessage provides a mock function with given fields: ctx, lastMessageId // GetMessage provides a mock function with given fields: ctx, lastMessageId, maxCapacity
func (_m *MockSessionService) GetMessage(ctx context.Context, lastMessageId int64) (*RunnerScaleSetMessage, error) { func (_m *MockSessionService) GetMessage(ctx context.Context, lastMessageId int64, maxCapacity int) (*RunnerScaleSetMessage, error) {
ret := _m.Called(ctx, lastMessageId) ret := _m.Called(ctx, lastMessageId, maxCapacity)
var r0 *RunnerScaleSetMessage var r0 *RunnerScaleSetMessage
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, int64) (*RunnerScaleSetMessage, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, int) (*RunnerScaleSetMessage, error)); ok {
return rf(ctx, lastMessageId) return rf(ctx, lastMessageId, maxCapacity)
} }
if rf, ok := ret.Get(0).(func(context.Context, int64) *RunnerScaleSetMessage); ok { if rf, ok := ret.Get(0).(func(context.Context, int64, int) *RunnerScaleSetMessage); ok {
r0 = rf(ctx, lastMessageId) r0 = rf(ctx, lastMessageId, maxCapacity)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*RunnerScaleSetMessage) r0 = ret.Get(0).(*RunnerScaleSetMessage)
} }
} }
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { if rf, ok := ret.Get(1).(func(context.Context, int64, int) error); ok {
r1 = rf(ctx, lastMessageId) r1 = rf(ctx, lastMessageId, maxCapacity)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }

View File

@@ -7,7 +7,7 @@ import (
//go:generate mockery --inpackage --name=SessionService //go:generate mockery --inpackage --name=SessionService
type SessionService interface { type SessionService interface {
GetMessage(ctx context.Context, lastMessageId int64) (*RunnerScaleSetMessage, error) GetMessage(ctx context.Context, lastMessageId int64, maxCapacity int) (*RunnerScaleSetMessage, error)
DeleteMessage(ctx context.Context, messageId int64) error DeleteMessage(ctx context.Context, messageId int64) error
AcquireJobs(ctx context.Context, requestIds []int64) ([]int64, error) AcquireJobs(ctx context.Context, requestIds []int64) ([]int64, error)
io.Closer io.Closer