From f65c223155bfedaf5ab275731d46c7078bd93800 Mon Sep 17 00:00:00 2001 From: Jackson Tian Date: Tue, 20 Aug 2024 17:53:32 +0800 Subject: [PATCH] Add default credentials provider --- credentials/internal/providers/default.go | 94 ++++++++++++++++ .../internal/providers/default_test.go | 102 ++++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 credentials/internal/providers/default.go create mode 100644 credentials/internal/providers/default_test.go diff --git a/credentials/internal/providers/default.go b/credentials/internal/providers/default.go new file mode 100644 index 0000000..8321d10 --- /dev/null +++ b/credentials/internal/providers/default.go @@ -0,0 +1,94 @@ +package providers + +import ( + "fmt" + "os" + "strings" +) + +type DefaultCredentialsProvider struct { + providerChain []CredentialsProvider + lastUsedProvider CredentialsProvider +} + +func NewDefaultCredentialsProvider() (provider *DefaultCredentialsProvider) { + providers := []CredentialsProvider{} + + // Add static ak or sts credentials provider + envProvider, err := NewEnvironmentVariableCredentialsProviderBuilder().Build() + if err == nil { + providers = append(providers, envProvider) + } + + // oidc check + oidcProvider, err := NewOIDCCredentialsProviderBuilder().Build() + if err == nil { + providers = append(providers, oidcProvider) + } + + // cli credentials provider + providers = append(providers, NewCLIProfileCredentialsProviderBuilder().Build()) + + // profile credentials provider + // providers = append(providers) + providers = append(providers, NewProfileCredentialsProviderBuilder().Build()) + + // Add IMDS + if os.Getenv("ALIBABA_CLOUD_ECS_METADATA") != "" { + ecsRamRoleProvider, err := NewECSRAMRoleCredentialsProviderBuilder().WithRoleName(os.Getenv("ALIBABA_CLOUD_ECS_METADATA")).Build() + if err == nil { + providers = append(providers, ecsRamRoleProvider) + } + } + + // TODO: ALIBABA_CLOUD_CREDENTIALS_URI check + + return &DefaultCredentialsProvider{ + providerChain: providers, + } +} + +func (provider *DefaultCredentialsProvider) GetCredentials() (cc *Credentials, err error) { + if provider.lastUsedProvider != nil { + inner, err1 := provider.lastUsedProvider.GetCredentials() + if err1 != nil { + return + } + + cc = &Credentials{ + AccessKeyId: inner.AccessKeyId, + AccessKeySecret: inner.AccessKeySecret, + SecurityToken: inner.SecurityToken, + ProviderName: fmt.Sprintf("%s/%s", provider.GetProviderName(), provider.lastUsedProvider.GetProviderName()), + } + return + } + + errors := []string{} + for _, p := range provider.providerChain { + provider.lastUsedProvider = p + inner, errInLoop := p.GetCredentials() + if errInLoop != nil { + errors = append(errors, errInLoop.Error()) + // 如果有错误,进入下一个获取过程 + continue + } + + if inner != nil { + cc = &Credentials{ + AccessKeyId: inner.AccessKeyId, + AccessKeySecret: inner.AccessKeySecret, + SecurityToken: inner.SecurityToken, + ProviderName: fmt.Sprintf("%s/%s", provider.GetProviderName(), p.GetProviderName()), + } + return + } + } + + err = fmt.Errorf("unable to get credentials from any of the providers in the chain: %s", strings.Join(errors, ", ")) + return +} + +func (provider *DefaultCredentialsProvider) GetProviderName() string { + return "default" +} diff --git a/credentials/internal/providers/default_test.go b/credentials/internal/providers/default_test.go new file mode 100644 index 0000000..989c128 --- /dev/null +++ b/credentials/internal/providers/default_test.go @@ -0,0 +1,102 @@ +package providers + +import ( + "os" + "testing" + + "github.com/aliyun/credentials-go/credentials/internal/utils" + "github.com/stretchr/testify/assert" +) + +func TestDefaultCredentialsProvider(t *testing.T) { + provider := NewDefaultCredentialsProvider() + assert.NotNil(t, provider) + assert.Len(t, provider.providerChain, 3) + _, ok := provider.providerChain[0].(*EnvironmentVariableCredentialsProvider) + assert.True(t, ok) + + _, ok = provider.providerChain[1].(*CLIProfileCredentialsProvider) + assert.True(t, ok) + + _, ok = provider.providerChain[2].(*ProfileCredentialsProvider) + assert.True(t, ok) + + // Add oidc provider + rollback := utils.Memory("ALIBABA_CLOUD_OIDC_TOKEN_FILE", + "ALIBABA_CLOUD_OIDC_PROVIDER_ARN", + "ALIBABA_CLOUD_ROLE_ARN", + "ALIBABA_CLOUD_ECS_METADATA") + + defer rollback() + os.Setenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "/path/to/oidc.token") + os.Setenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "oidcproviderarn") + os.Setenv("ALIBABA_CLOUD_ROLE_ARN", "rolearn") + + provider = NewDefaultCredentialsProvider() + assert.NotNil(t, provider) + assert.Len(t, provider.providerChain, 4) + _, ok = provider.providerChain[0].(*EnvironmentVariableCredentialsProvider) + assert.True(t, ok) + + _, ok = provider.providerChain[1].(*OIDCCredentialsProvider) + assert.True(t, ok) + + _, ok = provider.providerChain[2].(*CLIProfileCredentialsProvider) + assert.True(t, ok) + + _, ok = provider.providerChain[3].(*ProfileCredentialsProvider) + assert.True(t, ok) + + // Add ecs ram role + os.Setenv("ALIBABA_CLOUD_ECS_METADATA", "rolename") + provider = NewDefaultCredentialsProvider() + assert.NotNil(t, provider) + assert.Len(t, provider.providerChain, 5) + _, ok = provider.providerChain[0].(*EnvironmentVariableCredentialsProvider) + assert.True(t, ok) + + _, ok = provider.providerChain[1].(*OIDCCredentialsProvider) + assert.True(t, ok) + + _, ok = provider.providerChain[2].(*CLIProfileCredentialsProvider) + assert.True(t, ok) + + _, ok = provider.providerChain[3].(*ProfileCredentialsProvider) + assert.True(t, ok) + + _, ok = provider.providerChain[4].(*ECSRAMRoleCredentialsProvider) + assert.True(t, ok) +} + +func TestDefaultCredentialsProvider_GetCredentials(t *testing.T) { + rollback := utils.Memory("ALIBABA_CLOUD_ACCESS_KEY_ID", + "ALIBABA_CLOUD_ACCESS_KEY_SECRET", + "ALIBABA_CLOUD_SECURITY_TOKEN") + + defer func() { + getHomePath = utils.GetHomePath + rollback() + }() + + // testcase: empty home + getHomePath = func() string { + return "" + } + + provider := NewDefaultCredentialsProvider() + assert.Len(t, provider.providerChain, 3) + _, err := provider.GetCredentials() + assert.EqualError(t, err, "unable to get credentials from any of the providers in the chain: unable to get credentials from enviroment variables, Access key ID must be specified via environment variable (ALIBABA_CLOUD_ACCESS_KEY_ID), cannot found home dir, cannot found home dir") + + os.Setenv("ALIBABA_CLOUD_ACCESS_KEY_ID", "akid") + os.Setenv("ALIBABA_CLOUD_ACCESS_KEY_SECRET", "aksecret") + provider = NewDefaultCredentialsProvider() + assert.Len(t, provider.providerChain, 3) + cc, err := provider.GetCredentials() + assert.Nil(t, err) + assert.Equal(t, &Credentials{AccessKeyId: "akid", AccessKeySecret: "aksecret", SecurityToken: "", ProviderName: "default/env"}, cc) + // get again + cc, err = provider.GetCredentials() + assert.Nil(t, err) + assert.Equal(t, &Credentials{AccessKeyId: "akid", AccessKeySecret: "aksecret", SecurityToken: "", ProviderName: "default/env"}, cc) +}