From 47c2eab8869f487060c83d92bbc79c9f04178ca7 Mon Sep 17 00:00:00 2001 From: nanhe Date: Mon, 28 Oct 2024 16:43:45 +0800 Subject: [PATCH] feat: support env ALIBABA_CLOUD_STS_REGION for sts endpoint --- credentials/internal/providers/oidc.go | 2 ++ credentials/internal/providers/oidc_test.go | 10 ++++++- .../internal/providers/ram_role_arn.go | 3 ++ .../internal/providers/ram_role_arn_test.go | 29 +++++++++++++++---- 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/credentials/internal/providers/oidc.go b/credentials/internal/providers/oidc.go index 7ff0783..f140c25 100644 --- a/credentials/internal/providers/oidc.go +++ b/credentials/internal/providers/oidc.go @@ -128,6 +128,8 @@ func (b *OIDCCredentialsProviderBuilder) Build() (provider *OIDCCredentialsProvi if b.provider.stsEndpoint == "" { if b.provider.stsRegionId != "" { b.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", b.provider.stsRegionId) + } else if region := os.Getenv("ALIBABA_CLOUD_STS_REGION"); region != "" { + b.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", region) } else { b.provider.stsEndpoint = "sts.aliyuncs.com" } diff --git a/credentials/internal/providers/oidc_test.go b/credentials/internal/providers/oidc_test.go index 162ad1f..a247d04 100644 --- a/credentials/internal/providers/oidc_test.go +++ b/credentials/internal/providers/oidc_test.go @@ -36,7 +36,7 @@ func TestOIDCCredentialsProviderGetCredentialsWithError(t *testing.T) { } func TestNewOIDCCredentialsProvider(t *testing.T) { - rollback := utils.Memory("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "ALIBABA_CLOUD_ROLE_ARN") + rollback := utils.Memory("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "ALIBABA_CLOUD_ROLE_ARN", "ALIBABA_CLOUD_STS_REGION") defer func() { rollback() }() @@ -89,6 +89,14 @@ func TestNewOIDCCredentialsProvider(t *testing.T) { assert.Equal(t, "role_arn_from_env", p.roleArn) // sts endpoint: default assert.Equal(t, "sts.aliyuncs.com", p.stsEndpoint) + + // sts endpoint: with sts endpoint env + os.Setenv("ALIBABA_CLOUD_STS_REGION", "cn-hangzhou") + p, err = NewOIDCCredentialsProviderBuilder(). + Build() + assert.Nil(t, err) + assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", p.stsEndpoint) + // sts endpoint: with sts endpoint p, err = NewOIDCCredentialsProviderBuilder(). WithSTSEndpoint("sts.cn-shanghai.aliyuncs.com"). diff --git a/credentials/internal/providers/ram_role_arn.go b/credentials/internal/providers/ram_role_arn.go index 119efeb..ada9910 100644 --- a/credentials/internal/providers/ram_role_arn.go +++ b/credentials/internal/providers/ram_role_arn.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/url" + "os" "strconv" "strings" "time" @@ -146,6 +147,8 @@ func (builder *RAMRoleARNCredentialsProviderBuilder) Build() (provider *RAMRoleA if builder.provider.stsEndpoint == "" { if builder.provider.stsRegionId != "" { builder.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", builder.provider.stsRegionId) + } else if region := os.Getenv("ALIBABA_CLOUD_STS_REGION"); region != "" { + builder.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", region) } else { builder.provider.stsEndpoint = "sts.aliyuncs.com" } diff --git a/credentials/internal/providers/ram_role_arn_test.go b/credentials/internal/providers/ram_role_arn_test.go index 9551e91..bc87224 100644 --- a/credentials/internal/providers/ram_role_arn_test.go +++ b/credentials/internal/providers/ram_role_arn_test.go @@ -2,15 +2,21 @@ package providers import ( "errors" + "os" "strings" "testing" "time" httputil "github.com/aliyun/credentials-go/credentials/internal/http" + "github.com/aliyun/credentials-go/credentials/internal/utils" "github.com/stretchr/testify/assert" ) func TestNewRAMRoleARNCredentialsProvider(t *testing.T) { + rollback := utils.Memory("ALIBABA_CLOUD_STS_REGION") + defer func() { + rollback() + }() // case 1: no credentials provider _, err := NewRAMRoleARNCredentialsProviderBuilder(). Build() @@ -70,11 +76,10 @@ func TestNewRAMRoleARNCredentialsProvider(t *testing.T) { // sts endpoint with sts region assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", p.stsEndpoint) - // sts endpoint with sts endpoint + // default sts endpoint p, err = NewRAMRoleARNCredentialsProviderBuilder(). WithCredentialsProvider(akProvider). WithRoleArn("roleArn"). - WithStsEndpoint("sts.cn-shanghai.aliyuncs.com"). WithPolicy("policy"). WithExternalId("externalId"). WithRoleSessionName("rsn"). @@ -87,9 +92,10 @@ func TestNewRAMRoleARNCredentialsProvider(t *testing.T) { assert.Equal(t, "externalId", p.externalId) assert.Equal(t, "", p.stsRegionId) assert.Equal(t, 1000, p.durationSeconds) - assert.Equal(t, "sts.cn-shanghai.aliyuncs.com", p.stsEndpoint) + assert.Equal(t, "sts.aliyuncs.com", p.stsEndpoint) - // default sts endpoint + // sts endpoint with env + os.Setenv("ALIBABA_CLOUD_STS_REGION", "cn-hangzhou") p, err = NewRAMRoleARNCredentialsProviderBuilder(). WithCredentialsProvider(akProvider). WithRoleArn("roleArn"). @@ -99,13 +105,26 @@ func TestNewRAMRoleARNCredentialsProvider(t *testing.T) { WithDurationSeconds(1000). Build() assert.Nil(t, err) + assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", p.stsEndpoint) + + // sts endpoint with sts endpoint + p, err = NewRAMRoleARNCredentialsProviderBuilder(). + WithCredentialsProvider(akProvider). + WithRoleArn("roleArn"). + WithStsEndpoint("sts.cn-shanghai.aliyuncs.com"). + WithPolicy("policy"). + WithExternalId("externalId"). + WithRoleSessionName("rsn"). + WithDurationSeconds(1000). + Build() + assert.Nil(t, err) assert.Equal(t, "rsn", p.roleSessionName) assert.Equal(t, "roleArn", p.roleArn) assert.Equal(t, "policy", p.policy) assert.Equal(t, "externalId", p.externalId) assert.Equal(t, "", p.stsRegionId) assert.Equal(t, 1000, p.durationSeconds) - assert.Equal(t, "sts.aliyuncs.com", p.stsEndpoint) + assert.Equal(t, "sts.cn-shanghai.aliyuncs.com", p.stsEndpoint) } func TestRAMRoleARNCredentialsProvider_getCredentials(t *testing.T) {