Add support for enterprise runners (#290)

* Add support for enterprise runners

* update docs
This commit is contained in:
Jesse Haka
2021-02-05 02:31:06 +02:00
committed by GitHub
parent 831db9ee2a
commit 28e80a2d28
16 changed files with 207 additions and 102 deletions

View File

@@ -78,7 +78,7 @@ func (c *Config) NewClient() (*Client, error) {
}
// GetRegistrationToken returns a registration token tied with the name of repository and runner.
func (c *Client) GetRegistrationToken(ctx context.Context, org, repo, name string) (*github.RegistrationToken, error) {
func (c *Client) GetRegistrationToken(ctx context.Context, enterprise, org, repo, name string) (*github.RegistrationToken, error) {
c.mu.Lock()
defer c.mu.Unlock()
@@ -89,13 +89,13 @@ func (c *Client) GetRegistrationToken(ctx context.Context, org, repo, name strin
return rt, nil
}
owner, repo, err := getOwnerAndRepo(org, repo)
enterprise, owner, repo, err := getEnterpriseOrganisationAndRepo(enterprise, org, repo)
if err != nil {
return rt, err
}
rt, res, err := c.createRegistrationToken(ctx, owner, repo)
rt, res, err := c.createRegistrationToken(ctx, enterprise, owner, repo)
if err != nil {
return nil, fmt.Errorf("failed to create registration token: %v", err)
@@ -114,14 +114,14 @@ func (c *Client) GetRegistrationToken(ctx context.Context, org, repo, name strin
}
// RemoveRunner removes a runner with specified runner ID from repository.
func (c *Client) RemoveRunner(ctx context.Context, org, repo string, runnerID int64) error {
owner, repo, err := getOwnerAndRepo(org, repo)
func (c *Client) RemoveRunner(ctx context.Context, enterprise, org, repo string, runnerID int64) error {
enterprise, owner, repo, err := getEnterpriseOrganisationAndRepo(enterprise, org, repo)
if err != nil {
return err
}
res, err := c.removeRunner(ctx, owner, repo, runnerID)
res, err := c.removeRunner(ctx, enterprise, owner, repo, runnerID)
if err != nil {
return fmt.Errorf("failed to remove runner: %v", err)
@@ -135,8 +135,8 @@ func (c *Client) RemoveRunner(ctx context.Context, org, repo string, runnerID in
}
// ListRunners returns a list of runners of specified owner/repository name.
func (c *Client) ListRunners(ctx context.Context, org, repo string) ([]*github.Runner, error) {
owner, repo, err := getOwnerAndRepo(org, repo)
func (c *Client) ListRunners(ctx context.Context, enterprise, org, repo string) ([]*github.Runner, error) {
enterprise, owner, repo, err := getEnterpriseOrganisationAndRepo(enterprise, org, repo)
if err != nil {
return nil, err
@@ -146,7 +146,7 @@ func (c *Client) ListRunners(ctx context.Context, org, repo string) ([]*github.R
opts := github.ListOptions{PerPage: 10}
for {
list, res, err := c.listRunners(ctx, owner, repo, &opts)
list, res, err := c.listRunners(ctx, enterprise, owner, repo, &opts)
if err != nil {
return runners, fmt.Errorf("failed to list runners: %v", err)
@@ -174,42 +174,52 @@ func (c *Client) cleanup() {
}
}
// wrappers for github functions (switch between organization/repository mode)
// wrappers for github functions (switch between enterprise/organization/repository mode)
// so the calling functions don't need to switch and their code is a bit cleaner
func (c *Client) createRegistrationToken(ctx context.Context, owner, repo string) (*github.RegistrationToken, *github.Response, error) {
func (c *Client) createRegistrationToken(ctx context.Context, enterprise, org, repo string) (*github.RegistrationToken, *github.Response, error) {
if len(repo) > 0 {
return c.Client.Actions.CreateRegistrationToken(ctx, owner, repo)
}
return c.Client.Actions.CreateOrganizationRegistrationToken(ctx, owner)
}
func (c *Client) removeRunner(ctx context.Context, owner, repo string, runnerID int64) (*github.Response, error) {
if len(repo) > 0 {
return c.Client.Actions.RemoveRunner(ctx, owner, repo, runnerID)
}
return c.Client.Actions.RemoveOrganizationRunner(ctx, owner, runnerID)
}
func (c *Client) listRunners(ctx context.Context, owner, repo string, opts *github.ListOptions) (*github.Runners, *github.Response, error) {
if len(repo) > 0 {
return c.Client.Actions.ListRunners(ctx, owner, repo, opts)
}
return c.Client.Actions.ListOrganizationRunners(ctx, owner, opts)
}
// Validates owner and repo arguments. Both are optional, but at least one should be specified
func getOwnerAndRepo(org, repo string) (string, string, error) {
if len(repo) > 0 {
return splitOwnerAndRepo(repo)
return c.Client.Actions.CreateRegistrationToken(ctx, org, repo)
}
if len(org) > 0 {
return org, "", nil
return c.Client.Actions.CreateOrganizationRegistrationToken(ctx, org)
}
return "", "", fmt.Errorf("organization and repository are both empty")
return c.Client.Enterprise.CreateRegistrationToken(ctx, enterprise)
}
func (c *Client) removeRunner(ctx context.Context, enterprise, org, repo string, runnerID int64) (*github.Response, error) {
if len(repo) > 0 {
return c.Client.Actions.RemoveRunner(ctx, org, repo, runnerID)
}
if len(org) > 0 {
return c.Client.Actions.RemoveOrganizationRunner(ctx, org, runnerID)
}
return c.Client.Enterprise.RemoveRunner(ctx, enterprise, runnerID)
}
func (c *Client) listRunners(ctx context.Context, enterprise, org, repo string, opts *github.ListOptions) (*github.Runners, *github.Response, error) {
if len(repo) > 0 {
return c.Client.Actions.ListRunners(ctx, org, repo, opts)
}
if len(org) > 0 {
return c.Client.Actions.ListOrganizationRunners(ctx, org, opts)
}
return c.Client.Enterprise.ListRunners(ctx, enterprise, opts)
}
// Validates enterprise, organisation and repo arguments. Both are optional, but at least one should be specified
func getEnterpriseOrganisationAndRepo(enterprise, org, repo string) (string, string, string, error) {
if len(repo) > 0 {
owner, repository, err := splitOwnerAndRepo(repo)
return "", owner, repository, err
}
if len(org) > 0 {
return "", org, "", nil
}
if len(enterprise) > 0 {
return enterprise, "", "", nil
}
return "", "", "", fmt.Errorf("enterprise, organization and repository are all empty")
}
func getRegistrationKey(org, repo string) string {

View File

@@ -39,22 +39,26 @@ func TestMain(m *testing.M) {
func TestGetRegistrationToken(t *testing.T) {
tests := []struct {
org string
repo string
token string
err bool
enterprise string
org string
repo string
token string
err bool
}{
{org: "", repo: "test/valid", token: fake.RegistrationToken, err: false},
{org: "", repo: "test/invalid", token: "", err: true},
{org: "", repo: "test/error", token: "", err: true},
{org: "test", repo: "", token: fake.RegistrationToken, err: false},
{org: "invalid", repo: "", token: "", err: true},
{org: "error", repo: "", token: "", err: true},
{enterprise: "", org: "", repo: "test/valid", token: fake.RegistrationToken, err: false},
{enterprise: "", org: "", repo: "test/invalid", token: "", err: true},
{enterprise: "", org: "", repo: "test/error", token: "", err: true},
{enterprise: "", org: "test", repo: "", token: fake.RegistrationToken, err: false},
{enterprise: "", org: "invalid", repo: "", token: "", err: true},
{enterprise: "", org: "error", repo: "", token: "", err: true},
{enterprise: "test", org: "", repo: "", token: fake.RegistrationToken, err: false},
{enterprise: "invalid", org: "", repo: "", token: "", err: true},
{enterprise: "error", org: "", repo: "", token: "", err: true},
}
client := newTestClient()
for i, tt := range tests {
rt, err := client.GetRegistrationToken(context.Background(), tt.org, tt.repo, "test")
rt, err := client.GetRegistrationToken(context.Background(), tt.enterprise, tt.org, tt.repo, "test")
if !tt.err && err != nil {
t.Errorf("[%d] unexpected error: %v", i, err)
}
@@ -66,22 +70,26 @@ func TestGetRegistrationToken(t *testing.T) {
func TestListRunners(t *testing.T) {
tests := []struct {
org string
repo string
length int
err bool
enterprise string
org string
repo string
length int
err bool
}{
{org: "", repo: "test/valid", length: 2, err: false},
{org: "", repo: "test/invalid", length: 0, err: true},
{org: "", repo: "test/error", length: 0, err: true},
{org: "test", repo: "", length: 2, err: false},
{org: "invalid", repo: "", length: 0, err: true},
{org: "error", repo: "", length: 0, err: true},
{enterprise: "", org: "", repo: "test/valid", length: 2, err: false},
{enterprise: "", org: "", repo: "test/invalid", length: 0, err: true},
{enterprise: "", org: "", repo: "test/error", length: 0, err: true},
{enterprise: "", org: "test", repo: "", length: 2, err: false},
{enterprise: "", org: "invalid", repo: "", length: 0, err: true},
{enterprise: "", org: "error", repo: "", length: 0, err: true},
{enterprise: "test", org: "", repo: "", length: 2, err: false},
{enterprise: "invalid", org: "", repo: "", length: 0, err: true},
{enterprise: "error", org: "", repo: "", length: 0, err: true},
}
client := newTestClient()
for i, tt := range tests {
runners, err := client.ListRunners(context.Background(), tt.org, tt.repo)
runners, err := client.ListRunners(context.Background(), tt.enterprise, tt.org, tt.repo)
if !tt.err && err != nil {
t.Errorf("[%d] unexpected error: %v", i, err)
}
@@ -93,21 +101,25 @@ func TestListRunners(t *testing.T) {
func TestRemoveRunner(t *testing.T) {
tests := []struct {
org string
repo string
err bool
enterprise string
org string
repo string
err bool
}{
{org: "", repo: "test/valid", err: false},
{org: "", repo: "test/invalid", err: true},
{org: "", repo: "test/error", err: true},
{org: "test", repo: "", err: false},
{org: "invalid", repo: "", err: true},
{org: "error", repo: "", err: true},
{enterprise: "", org: "", repo: "test/valid", err: false},
{enterprise: "", org: "", repo: "test/invalid", err: true},
{enterprise: "", org: "", repo: "test/error", err: true},
{enterprise: "", org: "test", repo: "", err: false},
{enterprise: "", org: "invalid", repo: "", err: true},
{enterprise: "", org: "error", repo: "", err: true},
{enterprise: "test", org: "", repo: "", err: false},
{enterprise: "invalid", org: "", repo: "", err: true},
{enterprise: "error", org: "", repo: "", err: true},
}
client := newTestClient()
for i, tt := range tests {
err := client.RemoveRunner(context.Background(), tt.org, tt.repo, int64(1))
err := client.RemoveRunner(context.Background(), tt.enterprise, tt.org, tt.repo, int64(1))
if !tt.err && err != nil {
t.Errorf("[%d] unexpected error: %v", i, err)
}