Skip to content

Commit

Permalink
Add default credentials provider
Browse files Browse the repository at this point in the history
  • Loading branch information
JacksonTian authored and peze committed Aug 20, 2024
1 parent e98371f commit 36e3cfd
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 0 deletions.
94 changes: 94 additions & 0 deletions credentials/internal/providers/default.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package providers

import (
"fmt"
"os"
"strings"
)

type DefaultCredentialsProvider struct {
providerChain []CredentialsProvider
lastUsedProvider CredentialsProvider
}

func NewDefaultCredentialsProvider() (provider *DefaultCredentialsProvider) {
providers := []CredentialsProvider{}

// Add static ak or sts credentials provider
envProvider, err := NewEnvironmentVariableCredentialsProviderBuilder().Build()
if err == nil {
providers = append(providers, envProvider)
}

// oidc check
oidcProvider, err := NewOIDCCredentialsProviderBuilder().Build()
if err == nil {
providers = append(providers, oidcProvider)
}

// cli credentials provider
providers = append(providers, NewCLIProfileCredentialsProviderBuilder().Build())

// profile credentials provider
// providers = append(providers)
providers = append(providers, NewProfileCredentialsProviderBuilder().Build())

// Add IMDS
if os.Getenv("ALIBABA_CLOUD_ECS_METADATA") != "" {
ecsRamRoleProvider, err := NewECSRAMRoleCredentialsProviderBuilder().WithRoleName(os.Getenv("ALIBABA_CLOUD_ECS_METADATA")).Build()
if err == nil {
providers = append(providers, ecsRamRoleProvider)
}
}

// TODO: ALIBABA_CLOUD_CREDENTIALS_URI check

return &DefaultCredentialsProvider{
providerChain: providers,
}
}

func (provider *DefaultCredentialsProvider) GetCredentials() (cc *Credentials, err error) {
if provider.lastUsedProvider != nil {
inner, err1 := provider.lastUsedProvider.GetCredentials()
if err1 != nil {
return
}

cc = &Credentials{
AccessKeyId: inner.AccessKeyId,
AccessKeySecret: inner.AccessKeySecret,
SecurityToken: inner.SecurityToken,
ProviderName: fmt.Sprintf("%s/%s", provider.GetProviderName(), provider.lastUsedProvider.GetProviderName()),
}
return
}

errors := []string{}
for _, p := range provider.providerChain {
provider.lastUsedProvider = p
inner, errInLoop := p.GetCredentials()
if errInLoop != nil {
errors = append(errors, errInLoop.Error())
// 如果有错误,进入下一个获取过程
continue
}

if inner != nil {
cc = &Credentials{
AccessKeyId: inner.AccessKeyId,
AccessKeySecret: inner.AccessKeySecret,
SecurityToken: inner.SecurityToken,
ProviderName: fmt.Sprintf("%s/%s", provider.GetProviderName(), p.GetProviderName()),
}
return
}
}

err = fmt.Errorf("unable to get credentials from any of the providers in the chain: %s", strings.Join(errors, ", "))
return
}

func (provider *DefaultCredentialsProvider) GetProviderName() string {
return "default"
}
102 changes: 102 additions & 0 deletions credentials/internal/providers/default_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package providers

import (
"os"
"testing"

"github.com/aliyun/credentials-go/credentials/internal/utils"
"github.com/stretchr/testify/assert"
)

func TestDefaultCredentialsProvider(t *testing.T) {
provider := NewDefaultCredentialsProvider()
assert.NotNil(t, provider)
assert.Len(t, provider.providerChain, 3)
_, ok := provider.providerChain[0].(*EnvironmentVariableCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[1].(*CLIProfileCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[2].(*ProfileCredentialsProvider)
assert.True(t, ok)

// Add oidc provider
rollback := utils.Memory("ALIBABA_CLOUD_OIDC_TOKEN_FILE",
"ALIBABA_CLOUD_OIDC_PROVIDER_ARN",
"ALIBABA_CLOUD_ROLE_ARN",
"ALIBABA_CLOUD_ECS_METADATA")

defer rollback()
os.Setenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "/path/to/oidc.token")
os.Setenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "oidcproviderarn")
os.Setenv("ALIBABA_CLOUD_ROLE_ARN", "rolearn")

provider = NewDefaultCredentialsProvider()
assert.NotNil(t, provider)
assert.Len(t, provider.providerChain, 4)
_, ok = provider.providerChain[0].(*EnvironmentVariableCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[1].(*OIDCCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[2].(*CLIProfileCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[3].(*ProfileCredentialsProvider)
assert.True(t, ok)

// Add ecs ram role
os.Setenv("ALIBABA_CLOUD_ECS_METADATA", "rolename")
provider = NewDefaultCredentialsProvider()
assert.NotNil(t, provider)
assert.Len(t, provider.providerChain, 5)
_, ok = provider.providerChain[0].(*EnvironmentVariableCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[1].(*OIDCCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[2].(*CLIProfileCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[3].(*ProfileCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[4].(*ECSRAMRoleCredentialsProvider)
assert.True(t, ok)
}

func TestDefaultCredentialsProvider_GetCredentials(t *testing.T) {
rollback := utils.Memory("ALIBABA_CLOUD_ACCESS_KEY_ID",
"ALIBABA_CLOUD_ACCESS_KEY_SECRET",
"ALIBABA_CLOUD_SECURITY_TOKEN")

defer func() {
getHomePath = utils.GetHomePath
rollback()
}()

// testcase: empty home
getHomePath = func() string {
return ""
}

provider := NewDefaultCredentialsProvider()
assert.Len(t, provider.providerChain, 3)
_, err := provider.GetCredentials()
assert.EqualError(t, err, "unable to get credentials from any of the providers in the chain: unable to get credentials from enviroment variables, Access key ID must be specified via environment variable (ALIBABA_CLOUD_ACCESS_KEY_ID), cannot found home dir, cannot found home dir")

os.Setenv("ALIBABA_CLOUD_ACCESS_KEY_ID", "akid")
os.Setenv("ALIBABA_CLOUD_ACCESS_KEY_SECRET", "aksecret")
provider = NewDefaultCredentialsProvider()
assert.Len(t, provider.providerChain, 3)
cc, err := provider.GetCredentials()
assert.Nil(t, err)
assert.Equal(t, &Credentials{AccessKeyId: "akid", AccessKeySecret: "aksecret", SecurityToken: "", ProviderName: "default/env"}, cc)
// get again
cc, err = provider.GetCredentials()
assert.Nil(t, err)
assert.Equal(t, &Credentials{AccessKeyId: "akid", AccessKeySecret: "aksecret", SecurityToken: "", ProviderName: "default/env"}, cc)
}

0 comments on commit 36e3cfd

Please sign in to comment.