diff --git a/credentials/credential.go b/credentials/credential.go index 5ffe6a1..2739c6a 100644 --- a/credentials/credential.go +++ b/credentials/credential.go @@ -49,6 +49,7 @@ type Config struct { PublicKeyId *string `json:"public_key_id"` RoleName *string `json:"role_name"` EnableIMDSv2 *bool `json:"enable_imds_v2"` + DisableIMDSv1 *bool `json:"disable_imds_v1"` MetadataTokenDuration *int `json:"metadata_token_duration"` SessionExpiration *int `json:"session_expiration"` PrivateKeyFile *string `json:"private_key_file"` @@ -248,8 +249,7 @@ func NewCredential(config *Config) (credential Credential, err error) { case "ecs_ram_role": provider, err := providers.NewECSRAMRoleCredentialsProviderBuilder(). WithRoleName(tea.StringValue(config.RoleName)). - WithEnableIMDSv2(tea.BoolValue(config.EnableIMDSv2)). - WithMetadataTokenDurationSeconds(tea.IntValue(config.MetadataTokenDuration)). + WithDisableIMDSv1(tea.BoolValue(config.DisableIMDSv1)). Build() if err != nil { diff --git a/credentials/credential_test.go b/credentials/credential_test.go index 6243914..d602183 100644 --- a/credentials/credential_test.go +++ b/credentials/credential_test.go @@ -15,8 +15,8 @@ this is privatekey` func TestConfig(t *testing.T) { config := new(Config) - assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.String()) - assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.GoString()) + assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"disable_imds_v1\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.String()) + assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"disable_imds_v1\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.GoString()) config.SetSTSEndpoint("sts.cn-hangzhou.aliyuncs.com") assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", *config.STSEndpoint) diff --git a/credentials/internal/providers/ecs_ram_role.go b/credentials/internal/providers/ecs_ram_role.go index ac5b036..466031d 100644 --- a/credentials/internal/providers/ecs_ram_role.go +++ b/credentials/internal/providers/ecs_ram_role.go @@ -2,7 +2,6 @@ package providers import ( "encoding/json" - "errors" "fmt" "os" "strconv" @@ -13,9 +12,8 @@ import ( ) type ECSRAMRoleCredentialsProvider struct { - roleName string - metadataTokenDurationSeconds int - enableIMDSv2 bool + roleName string + disableIMDSv1 bool // for sts session *sessionCredentials expirationTimestamp int64 @@ -27,43 +25,31 @@ type ECSRAMRoleCredentialsProviderBuilder struct { func NewECSRAMRoleCredentialsProviderBuilder() *ECSRAMRoleCredentialsProviderBuilder { return &ECSRAMRoleCredentialsProviderBuilder{ - provider: &ECSRAMRoleCredentialsProvider{ - // TBD: 默认启用 IMDS v2 - // enableIMDSv2: os.Getenv("ALIBABA_CLOUD_IMDSV2_DISABLED") != "true", // 默认启用 v2 - }, + provider: &ECSRAMRoleCredentialsProvider{}, } } -func (builder *ECSRAMRoleCredentialsProviderBuilder) WithMetadataTokenDurationSeconds(metadataTokenDurationSeconds int) *ECSRAMRoleCredentialsProviderBuilder { - builder.provider.metadataTokenDurationSeconds = metadataTokenDurationSeconds - return builder -} - func (builder *ECSRAMRoleCredentialsProviderBuilder) WithRoleName(roleName string) *ECSRAMRoleCredentialsProviderBuilder { builder.provider.roleName = roleName return builder } -func (builder *ECSRAMRoleCredentialsProviderBuilder) WithEnableIMDSv2(enableIMDSv2 bool) *ECSRAMRoleCredentialsProviderBuilder { - builder.provider.enableIMDSv2 = enableIMDSv2 +func (builder *ECSRAMRoleCredentialsProviderBuilder) WithDisableIMDSv1(disableIMDSv1 bool) *ECSRAMRoleCredentialsProviderBuilder { + builder.provider.disableIMDSv1 = disableIMDSv1 return builder } const defaultMetadataTokenDuration = 21600 // 6 hours func (builder *ECSRAMRoleCredentialsProviderBuilder) Build() (provider *ECSRAMRoleCredentialsProvider, err error) { + // 设置 roleName 默认值 if builder.provider.roleName == "" { builder.provider.roleName = os.Getenv("ALIBABA_CLOUD_ECS_METADATA") } - if builder.provider.metadataTokenDurationSeconds == 0 { - builder.provider.metadataTokenDurationSeconds = defaultMetadataTokenDuration - } - - if builder.provider.metadataTokenDurationSeconds < 1 || builder.provider.metadataTokenDurationSeconds > 21600 { - err = errors.New("the metadata token duration seconds must be 1-21600") - return + if !builder.provider.disableIMDSv1 { + builder.provider.disableIMDSv1 = os.Getenv("ALIBABA_CLOUD_IMDSV1_DISABLE") == "true" } provider = builder.provider @@ -98,11 +84,11 @@ func (provider *ECSRAMRoleCredentialsProvider) getRoleName() (roleName string, e Headers: map[string]string{}, } - if provider.enableIMDSv2 { - metadataToken, err := provider.getMetadataToken() - if err != nil { - return "", err - } + metadataToken, err := provider.getMetadataToken() + if err != nil { + return "", err + } + if metadataToken != "" { req.Headers["x-aliyun-ecs-metadata-token"] = metadataToken } @@ -140,11 +126,11 @@ func (provider *ECSRAMRoleCredentialsProvider) getCredentials() (session *sessio Headers: map[string]string{}, } - if provider.enableIMDSv2 { - metadataToken, err := provider.getMetadataToken() - if err != nil { - return nil, err - } + metadataToken, err := provider.getMetadataToken() + if err != nil { + return nil, err + } + if metadataToken != "" { req.Headers["x-aliyun-ecs-metadata-token"] = metadataToken } @@ -221,14 +207,22 @@ func (provider *ECSRAMRoleCredentialsProvider) getMetadataToken() (metadataToken Host: "100.100.100.200", Path: "/latest/api/token", Headers: map[string]string{ - "X-aliyun-ecs-metadata-token-ttl-seconds": strconv.Itoa(provider.metadataTokenDurationSeconds), + "X-aliyun-ecs-metadata-token-ttl-seconds": strconv.Itoa(defaultMetadataTokenDuration), }, ConnectTimeout: 5 * time.Second, ReadTimeout: 5 * time.Second, } - res, err := httpDo(req) - if err != nil { - err = fmt.Errorf("get metadata token failed: %s", err.Error()) + res, _err := httpDo(req) + if _err != nil { + if provider.disableIMDSv1 { + err = fmt.Errorf("get metadata token failed: %s", _err.Error()) + } + return + } + if res.StatusCode != 200 { + if provider.disableIMDSv1 { + err = fmt.Errorf("refresh Ecs sts token err, httpStatus: %d, message = %s", res.StatusCode, string(res.Body)) + } return } metadataToken = string(res.Body) diff --git a/credentials/internal/providers/ecs_ram_role_test.go b/credentials/internal/providers/ecs_ram_role_test.go index d4cf538..7fc987c 100644 --- a/credentials/internal/providers/ecs_ram_role_test.go +++ b/credentials/internal/providers/ecs_ram_role_test.go @@ -2,6 +2,7 @@ package providers import ( "errors" + "os" "testing" "time" @@ -13,15 +14,10 @@ func TestNewECSRAMRoleCredentialsProvider(t *testing.T) { p, err := NewECSRAMRoleCredentialsProviderBuilder().Build() assert.Nil(t, err) assert.Equal(t, "", p.roleName) - assert.Equal(t, 21600, p.metadataTokenDurationSeconds) - _, err = NewECSRAMRoleCredentialsProviderBuilder().WithMetadataTokenDurationSeconds(1000000000).Build() - assert.EqualError(t, err, "the metadata token duration seconds must be 1-21600") - - p, err = NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("role").WithMetadataTokenDurationSeconds(3600).Build() + p, err = NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("role").Build() assert.Nil(t, err) assert.Equal(t, "role", p.roleName) - assert.Equal(t, 3600, p.metadataTokenDurationSeconds) assert.True(t, p.needUpdateCredential()) } @@ -73,7 +69,7 @@ func TestECSRAMRoleCredentialsProvider_getRoleNameWithMetadataV2(t *testing.T) { originHttpDo := httpDo defer func() { httpDo = originHttpDo }() - p, err := NewECSRAMRoleCredentialsProviderBuilder().WithEnableIMDSv2(true).Build() + p, err := NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).Build() assert.Nil(t, err) // case 1: get metadata token failed @@ -281,7 +277,7 @@ func TestECSRAMRoleCredentialsProvider_getCredentialsWithMetadataV2(t *testing.T originHttpDo := httpDo defer func() { httpDo = originHttpDo }() - p, err := NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("rolename").WithEnableIMDSv2(true).Build() + p, err := NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).WithRoleName("rolename").Build() assert.Nil(t, err) // case 1: get metadata token failed @@ -383,9 +379,37 @@ func TestECSRAMRoleCredentialsProvider_getMetadataToken(t *testing.T) { return } + _, err = p.getMetadataToken() + assert.Nil(t, err) + + p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(false).Build() + assert.Nil(t, err) + + _, err = p.getMetadataToken() + assert.Nil(t, err) + + os.Setenv("ALIBABA_CLOUD_IMDSV1_DISABLE", "true") + p, err = NewECSRAMRoleCredentialsProviderBuilder().Build() + assert.Nil(t, err) + + _, err = p.getMetadataToken() + assert.NotNil(t, err) + + os.Setenv("ALIBABA_CLOUD_IMDSV1_DISABLE", "") + p, err = NewECSRAMRoleCredentialsProviderBuilder().Build() + assert.Nil(t, err) + + _, err = p.getMetadataToken() + assert.Nil(t, err) + + p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).Build() + assert.Nil(t, err) + _, err = p.getMetadataToken() assert.NotNil(t, err) + assert.Equal(t, "get metadata token failed: mock server error", err.Error()) + // case 2: return token httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { res = &httputil.Response{ @@ -397,4 +421,26 @@ func TestECSRAMRoleCredentialsProvider_getMetadataToken(t *testing.T) { metadataToken, err := p.getMetadataToken() assert.Nil(t, err) assert.Equal(t, "tokenxxxxx", metadataToken) + + // case 3: return 404 + p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(false).Build() + assert.Nil(t, err) + + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 404, + Body: []byte("not found"), + } + return + } + metadataToken, err = p.getMetadataToken() + assert.Nil(t, err) + assert.Equal(t, "", metadataToken) + + p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).Build() + assert.Nil(t, err) + + metadataToken, err = p.getMetadataToken() + assert.NotNil(t, err) + assert.Equal(t, "", metadataToken) }