Skip to content

Commit

Permalink
feat: support IMDS v2 default for ecs ram role
Browse files Browse the repository at this point in the history
  • Loading branch information
yndu13 authored and JacksonTian committed Sep 12, 2024
1 parent af7c74d commit f2abf11
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 48 deletions.
4 changes: 2 additions & 2 deletions credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type Config struct {
PublicKeyId *string `json:"public_key_id"`
RoleName *string `json:"role_name"`
EnableIMDSv2 *bool `json:"enable_imds_v2"`
DisableIMDSv1 *bool `json:"disable_imds_v1"`
MetadataTokenDuration *int `json:"metadata_token_duration"`
SessionExpiration *int `json:"session_expiration"`
PrivateKeyFile *string `json:"private_key_file"`
Expand Down Expand Up @@ -248,8 +249,7 @@ func NewCredential(config *Config) (credential Credential, err error) {
case "ecs_ram_role":
provider, err := providers.NewECSRAMRoleCredentialsProviderBuilder().
WithRoleName(tea.StringValue(config.RoleName)).
WithEnableIMDSv2(tea.BoolValue(config.EnableIMDSv2)).
WithMetadataTokenDurationSeconds(tea.IntValue(config.MetadataTokenDuration)).
WithDisableIMDSv1(tea.BoolValue(config.DisableIMDSv1)).
Build()

if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions credentials/credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ this is privatekey`

func TestConfig(t *testing.T) {
config := new(Config)
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.String())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.GoString())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"disable_imds_v1\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.String())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"disable_imds_v1\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.GoString())

config.SetSTSEndpoint("sts.cn-hangzhou.aliyuncs.com")
assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", *config.STSEndpoint)
Expand Down
66 changes: 30 additions & 36 deletions credentials/internal/providers/ecs_ram_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package providers

import (
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
Expand All @@ -13,9 +12,8 @@ import (
)

type ECSRAMRoleCredentialsProvider struct {
roleName string
metadataTokenDurationSeconds int
enableIMDSv2 bool
roleName string
disableIMDSv1 bool
// for sts
session *sessionCredentials
expirationTimestamp int64
Expand All @@ -27,43 +25,31 @@ type ECSRAMRoleCredentialsProviderBuilder struct {

func NewECSRAMRoleCredentialsProviderBuilder() *ECSRAMRoleCredentialsProviderBuilder {
return &ECSRAMRoleCredentialsProviderBuilder{
provider: &ECSRAMRoleCredentialsProvider{
// TBD: 默认启用 IMDS v2
// enableIMDSv2: os.Getenv("ALIBABA_CLOUD_IMDSV2_DISABLED") != "true", // 默认启用 v2
},
provider: &ECSRAMRoleCredentialsProvider{},
}
}

func (builder *ECSRAMRoleCredentialsProviderBuilder) WithMetadataTokenDurationSeconds(metadataTokenDurationSeconds int) *ECSRAMRoleCredentialsProviderBuilder {
builder.provider.metadataTokenDurationSeconds = metadataTokenDurationSeconds
return builder
}

func (builder *ECSRAMRoleCredentialsProviderBuilder) WithRoleName(roleName string) *ECSRAMRoleCredentialsProviderBuilder {
builder.provider.roleName = roleName
return builder
}

func (builder *ECSRAMRoleCredentialsProviderBuilder) WithEnableIMDSv2(enableIMDSv2 bool) *ECSRAMRoleCredentialsProviderBuilder {
builder.provider.enableIMDSv2 = enableIMDSv2
func (builder *ECSRAMRoleCredentialsProviderBuilder) WithDisableIMDSv1(disableIMDSv1 bool) *ECSRAMRoleCredentialsProviderBuilder {
builder.provider.disableIMDSv1 = disableIMDSv1
return builder
}

const defaultMetadataTokenDuration = 21600 // 6 hours

func (builder *ECSRAMRoleCredentialsProviderBuilder) Build() (provider *ECSRAMRoleCredentialsProvider, err error) {

// 设置 roleName 默认值
if builder.provider.roleName == "" {
builder.provider.roleName = os.Getenv("ALIBABA_CLOUD_ECS_METADATA")
}

if builder.provider.metadataTokenDurationSeconds == 0 {
builder.provider.metadataTokenDurationSeconds = defaultMetadataTokenDuration
}

if builder.provider.metadataTokenDurationSeconds < 1 || builder.provider.metadataTokenDurationSeconds > 21600 {
err = errors.New("the metadata token duration seconds must be 1-21600")
return
if !builder.provider.disableIMDSv1 {
builder.provider.disableIMDSv1 = os.Getenv("ALIBABA_CLOUD_IMDSV1_DISABLE") == "true"
}

provider = builder.provider
Expand Down Expand Up @@ -98,11 +84,11 @@ func (provider *ECSRAMRoleCredentialsProvider) getRoleName() (roleName string, e
Headers: map[string]string{},
}

if provider.enableIMDSv2 {
metadataToken, err := provider.getMetadataToken()
if err != nil {
return "", err
}
metadataToken, err := provider.getMetadataToken()
if err != nil {
return "", err
}
if metadataToken != "" {
req.Headers["x-aliyun-ecs-metadata-token"] = metadataToken
}

Expand Down Expand Up @@ -140,11 +126,11 @@ func (provider *ECSRAMRoleCredentialsProvider) getCredentials() (session *sessio
Headers: map[string]string{},
}

if provider.enableIMDSv2 {
metadataToken, err := provider.getMetadataToken()
if err != nil {
return nil, err
}
metadataToken, err := provider.getMetadataToken()
if err != nil {
return nil, err
}
if metadataToken != "" {
req.Headers["x-aliyun-ecs-metadata-token"] = metadataToken
}

Expand Down Expand Up @@ -221,14 +207,22 @@ func (provider *ECSRAMRoleCredentialsProvider) getMetadataToken() (metadataToken
Host: "100.100.100.200",
Path: "/latest/api/token",
Headers: map[string]string{
"X-aliyun-ecs-metadata-token-ttl-seconds": strconv.Itoa(provider.metadataTokenDurationSeconds),
"X-aliyun-ecs-metadata-token-ttl-seconds": strconv.Itoa(defaultMetadataTokenDuration),
},
ConnectTimeout: 5 * time.Second,
ReadTimeout: 5 * time.Second,
}
res, err := httpDo(req)
if err != nil {
err = fmt.Errorf("get metadata token failed: %s", err.Error())
res, _err := httpDo(req)
if _err != nil {
if provider.disableIMDSv1 {
err = fmt.Errorf("get metadata token failed: %s", _err.Error())
}
return
}
if res.StatusCode != 200 {
if provider.disableIMDSv1 {
err = fmt.Errorf("refresh Ecs sts token err, httpStatus: %d, message = %s", res.StatusCode, string(res.Body))
}
return
}
metadataToken = string(res.Body)
Expand Down
62 changes: 54 additions & 8 deletions credentials/internal/providers/ecs_ram_role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package providers

import (
"errors"
"os"
"testing"
"time"

Expand All @@ -13,15 +14,10 @@ func TestNewECSRAMRoleCredentialsProvider(t *testing.T) {
p, err := NewECSRAMRoleCredentialsProviderBuilder().Build()
assert.Nil(t, err)
assert.Equal(t, "", p.roleName)
assert.Equal(t, 21600, p.metadataTokenDurationSeconds)

_, err = NewECSRAMRoleCredentialsProviderBuilder().WithMetadataTokenDurationSeconds(1000000000).Build()
assert.EqualError(t, err, "the metadata token duration seconds must be 1-21600")

p, err = NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("role").WithMetadataTokenDurationSeconds(3600).Build()
p, err = NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("role").Build()
assert.Nil(t, err)
assert.Equal(t, "role", p.roleName)
assert.Equal(t, 3600, p.metadataTokenDurationSeconds)

assert.True(t, p.needUpdateCredential())
}
Expand Down Expand Up @@ -73,7 +69,7 @@ func TestECSRAMRoleCredentialsProvider_getRoleNameWithMetadataV2(t *testing.T) {
originHttpDo := httpDo
defer func() { httpDo = originHttpDo }()

p, err := NewECSRAMRoleCredentialsProviderBuilder().WithEnableIMDSv2(true).Build()
p, err := NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).Build()
assert.Nil(t, err)

// case 1: get metadata token failed
Expand Down Expand Up @@ -281,7 +277,7 @@ func TestECSRAMRoleCredentialsProvider_getCredentialsWithMetadataV2(t *testing.T
originHttpDo := httpDo
defer func() { httpDo = originHttpDo }()

p, err := NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("rolename").WithEnableIMDSv2(true).Build()
p, err := NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).WithRoleName("rolename").Build()
assert.Nil(t, err)

// case 1: get metadata token failed
Expand Down Expand Up @@ -383,9 +379,37 @@ func TestECSRAMRoleCredentialsProvider_getMetadataToken(t *testing.T) {
return
}

_, err = p.getMetadataToken()
assert.Nil(t, err)

p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(false).Build()
assert.Nil(t, err)

_, err = p.getMetadataToken()
assert.Nil(t, err)

os.Setenv("ALIBABA_CLOUD_IMDSV1_DISABLE", "true")
p, err = NewECSRAMRoleCredentialsProviderBuilder().Build()
assert.Nil(t, err)

_, err = p.getMetadataToken()
assert.NotNil(t, err)

os.Setenv("ALIBABA_CLOUD_IMDSV1_DISABLE", "")
p, err = NewECSRAMRoleCredentialsProviderBuilder().Build()
assert.Nil(t, err)

_, err = p.getMetadataToken()
assert.Nil(t, err)

p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).Build()
assert.Nil(t, err)

_, err = p.getMetadataToken()
assert.NotNil(t, err)

assert.Equal(t, "get metadata token failed: mock server error", err.Error())

// case 2: return token
httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
res = &httputil.Response{
Expand All @@ -397,4 +421,26 @@ func TestECSRAMRoleCredentialsProvider_getMetadataToken(t *testing.T) {
metadataToken, err := p.getMetadataToken()
assert.Nil(t, err)
assert.Equal(t, "tokenxxxxx", metadataToken)

// case 3: return 404
p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(false).Build()
assert.Nil(t, err)

httpDo = func(req *httputil.Request) (res *httputil.Response, err error) {
res = &httputil.Response{
StatusCode: 404,
Body: []byte("not found"),
}
return
}
metadataToken, err = p.getMetadataToken()
assert.Nil(t, err)
assert.Equal(t, "", metadataToken)

p, err = NewECSRAMRoleCredentialsProviderBuilder().WithDisableIMDSv1(true).Build()
assert.Nil(t, err)

metadataToken, err = p.getMetadataToken()
assert.NotNil(t, err)
assert.Equal(t, "", metadataToken)
}

0 comments on commit f2abf11

Please sign in to comment.