Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support IMDSv2 for ECS metadata #69

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 40 additions & 5 deletions credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ type Config struct {
RoleSessionName *string `json:"role_session_name"`
PublicKeyId *string `json:"public_key_id"`
RoleName *string `json:"role_name"`
EnableIMDSv2 *bool `json:"enable_imds_v2"`
MetadataTokenDuration *int `json:"metadata_token_duration"`
SessionExpiration *int `json:"session_expiration"`
PrivateKeyFile *string `json:"private_key_file"`
BearerToken *string `json:"bearer_token"`
Expand Down Expand Up @@ -106,6 +108,16 @@ func (s *Config) SetRoleName(v string) *Config {
return s
}

func (s *Config) SetEnableIMDSv2(v bool) *Config {
s.EnableIMDSv2 = &v
return s
}

func (s *Config) SetMetadataTokenDuration(v int) *Config {
s.MetadataTokenDuration = &v
return s
}

func (s *Config) SetSessionExpiration(v int) *Config {
s.SessionExpiration = &v
return s
Expand Down Expand Up @@ -205,19 +217,33 @@ func NewCredential(config *Config) (credential Credential, err error) {
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
STSEndpoint: tea.StringValue(config.STSEndpoint),
}
credential = newOIDCRoleArnCredential(tea.StringValue(config.AccessKeyId), tea.StringValue(config.AccessKeySecret), tea.StringValue(config.RoleArn), tea.StringValue(config.OIDCProviderArn), tea.StringValue(config.OIDCTokenFilePath), tea.StringValue(config.RoleSessionName), tea.StringValue(config.Policy), tea.IntValue(config.RoleSessionExpiration), runtime)
credential = newOIDCRoleArnCredential(
tea.StringValue(config.AccessKeyId),
tea.StringValue(config.AccessKeySecret),
tea.StringValue(config.RoleArn),
tea.StringValue(config.OIDCProviderArn),
tea.StringValue(config.OIDCTokenFilePath),
tea.StringValue(config.RoleSessionName),
tea.StringValue(config.Policy),
tea.IntValue(config.RoleSessionExpiration),
runtime)
case "access_key":
err = checkAccessKey(config)
if err != nil {
return
}
credential = newAccessKeyCredential(tea.StringValue(config.AccessKeyId), tea.StringValue(config.AccessKeySecret))
credential = newAccessKeyCredential(
tea.StringValue(config.AccessKeyId),
tea.StringValue(config.AccessKeySecret))
case "sts":
err = checkSTS(config)
if err != nil {
return
}
credential = newStsTokenCredential(tea.StringValue(config.AccessKeyId), tea.StringValue(config.AccessKeySecret), tea.StringValue(config.SecurityToken))
credential = newStsTokenCredential(
tea.StringValue(config.AccessKeyId),
tea.StringValue(config.AccessKeySecret),
tea.StringValue(config.SecurityToken))
case "ecs_ram_role":
checkEcsRAMRole(config)
runtime := &utils.Runtime{
Expand All @@ -226,7 +252,12 @@ func NewCredential(config *Config) (credential Credential, err error) {
ReadTimeout: tea.IntValue(config.Timeout),
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
}
credential = newEcsRAMRoleCredential(tea.StringValue(config.RoleName), tea.Float64Value(config.InAdvanceScale), runtime)
credential = newEcsRAMRoleCredentialWithEnableIMDSv2(
tea.StringValue(config.RoleName),
tea.BoolValue(config.EnableIMDSv2),
tea.IntValue(config.MetadataTokenDuration),
tea.Float64Value(config.InAdvanceScale),
runtime)
case "ram_role_arn":
err = checkRAMRoleArn(config)
if err != nil {
Expand Down Expand Up @@ -274,7 +305,11 @@ func NewCredential(config *Config) (credential Credential, err error) {
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
STSEndpoint: tea.StringValue(config.STSEndpoint),
}
credential = newRsaKeyPairCredential(privateKey, tea.StringValue(config.PublicKeyId), tea.IntValue(config.SessionExpiration), runtime)
credential = newRsaKeyPairCredential(
privateKey,
tea.StringValue(config.PublicKeyId),
tea.IntValue(config.SessionExpiration),
runtime)
case "bearer":
if tea.StringValue(config.BearerToken) == "" {
err = errors.New("BearerToken cannot be empty")
Expand Down
20 changes: 18 additions & 2 deletions credentials/credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ this is privatekey`

func TestConfig(t *testing.T) {
config := new(Config)
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.String())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.GoString())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.String())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.GoString())

config.SetSTSEndpoint("sts.cn-hangzhou.aliyuncs.com")
assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", *config.STSEndpoint)
Expand Down Expand Up @@ -96,6 +96,22 @@ func TestNewCredentialWithECSRAMRole(t *testing.T) {
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)

config.SetEnableIMDSv2(false)
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)

config.SetEnableIMDSv2(true)
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)

config.SetEnableIMDSv2(true)
config.SetMetadataTokenDuration(180)
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)
}

func TestNewCredentialWithRSAKeyPair(t *testing.T) {
Expand Down
60 changes: 57 additions & 3 deletions credentials/ecs_ram_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package credentials
import (
"encoding/json"
"fmt"
"strconv"
"time"

"github.com/alibabacloud-go/tea/tea"
Expand All @@ -11,13 +12,20 @@ import (
)

var securityCredURL = "http://100.100.100.200/latest/meta-data/ram/security-credentials/"
var securityCredTokenURL = "http://100.100.100.200/latest/api/token"

const defaultMetadataTokenDuration = int(21600)

// EcsRAMRoleCredential is a kind of credential
type EcsRAMRoleCredential struct {
*credentialUpdater
RoleName string
sessionCredential *sessionCredential
runtime *utils.Runtime
RoleName string
EnableIMDSv2 bool
MetadataTokenDuration int
sessionCredential *sessionCredential
runtime *utils.Runtime
metadataToken string
staleTime int64
}

type ecsRAMRoleResponse struct {
Expand All @@ -40,6 +48,20 @@ func newEcsRAMRoleCredential(roleName string, inAdvanceScale float64, runtime *u
}
}

func newEcsRAMRoleCredentialWithEnableIMDSv2(roleName string, enableIMDSv2 bool, metadataTokenDuration int, inAdvanceScale float64, runtime *utils.Runtime) *EcsRAMRoleCredential {
credentialUpdater := new(credentialUpdater)
if inAdvanceScale < 1 && inAdvanceScale > 0 {
credentialUpdater.inAdvanceScale = inAdvanceScale
}
return &EcsRAMRoleCredential{
RoleName: roleName,
EnableIMDSv2: enableIMDSv2,
MetadataTokenDuration: metadataTokenDuration,
credentialUpdater: credentialUpdater,
runtime: runtime,
}
}

func (e *EcsRAMRoleCredential) GetCredential() (*CredentialModel, error) {
if e.sessionCredential == nil || e.needUpdateCredential() {
err := e.updateCredential()
Expand Down Expand Up @@ -123,6 +145,26 @@ func getRoleName() (string, error) {
return string(content), nil
}

func (e *EcsRAMRoleCredential) getMetadataToken() (err error) {
if e.needToRefresh() {
if e.MetadataTokenDuration <= 0 {
e.MetadataTokenDuration = defaultMetadataTokenDuration
}
tmpTime := time.Now().Unix() + int64(e.MetadataTokenDuration*1000)
request := request.NewCommonRequest()
request.URL = securityCredTokenURL
request.Method = "PUT"
request.Headers["X-aliyun-ecs-metadata-token-ttl-seconds"] = strconv.Itoa(e.MetadataTokenDuration)
content, err := doAction(request, e.runtime)
if err != nil {
return err
}
e.staleTime = tmpTime
e.metadataToken = string(content)
}
return
}

func (e *EcsRAMRoleCredential) updateCredential() (err error) {
if e.runtime == nil {
e.runtime = new(utils.Runtime)
Expand All @@ -134,6 +176,13 @@ func (e *EcsRAMRoleCredential) updateCredential() (err error) {
return fmt.Errorf("refresh Ecs sts token err: %s", err.Error())
}
}
if e.EnableIMDSv2 {
err = e.getMetadataToken()
if err != nil {
return fmt.Errorf("Failed to get token from ECS Metadata Service: %s", err.Error())
}
request.Headers["X-aliyun-ecs-metadata-token"] = e.metadataToken
}
request.URL = securityCredURL + e.RoleName
request.Method = "GET"
content, err := doAction(request, e.runtime)
Expand Down Expand Up @@ -163,3 +212,8 @@ func (e *EcsRAMRoleCredential) updateCredential() (err error) {

return
}

func (e *EcsRAMRoleCredential) needToRefresh() (needToRefresh bool) {
needToRefresh = time.Now().Unix() >= e.staleTime
return
}
127 changes: 127 additions & 0 deletions credentials/ecs_ram_role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,130 @@ func Test_EcsRAmRoleCredential(t *testing.T) {
assert.Equal(t, "refresh Ecs sts token err: error parse", err.Error())
assert.Equal(t, "", *accesskeyId)
}

func Test_EcsRAmRoleCredentialEnableIMDSv2(t *testing.T) {
auth := newEcsRAMRoleCredentialWithEnableIMDSv2("go sdk", false, 0, 0.5, nil)
origTestHookDo := hookDo
defer func() { hookDo = origTestHookDo }()

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(300, ``, errors.New("sdk test"))
}
}
accesskeyId, err := auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: sdk test", err.Error())
assert.Equal(t, "", *accesskeyId)

auth = newEcsRAMRoleCredentialWithEnableIMDSv2("go sdk", true, 0, 0.5, nil)
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "Failed to get token from ECS Metadata Service: sdk test", err.Error())
assert.Equal(t, "", *accesskeyId)

auth = newEcsRAMRoleCredentialWithEnableIMDSv2("go sdk", true, 180, 0.5, nil)
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "Failed to get token from ECS Metadata Service: sdk test", err.Error())
assert.Equal(t, "", *accesskeyId)

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(300, ``, nil)
}
}
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "Failed to get token from ECS Metadata Service: httpStatus: 300, message = ", err.Error())
assert.Equal(t, "", *accesskeyId)

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(400, `role`, nil)
}
}
auth.RoleName = ""
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: httpStatus: 400, message = role", err.Error())

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(200, `role`, nil)
}
}
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: Json Unmarshal fail: invalid character 'r' looking for beginning of value", err.Error())
hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(200, `"AccessKeyId":"accessKeyId","AccessKeySecret":"accessKeySecret","SecurityToken":"securitytoken","Expiration":"expiration"`, nil)
}
}
auth.RoleName = "role"
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: Json Unmarshal fail: invalid character ':' after top-level value", err.Error())
assert.Equal(t, "", *accesskeyId)

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(200, `{"AccessKeySecret":"accessKeySecret","SecurityToken":"securitytoken","Expiration":"expiration","Code":"fail"}`, nil)
}
}
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: Code is not Success", err.Error())
assert.Equal(t, "", *accesskeyId)

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(200, `{"AccessKeySecret":"accessKeySecret","SecurityToken":"securitytoken","Expiration":"expiration","Code":"Success"}`, nil)
}
}
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: AccessKeyId: , AccessKeySecret: accessKeySecret, SecurityToken: securitytoken, Expiration: expiration", err.Error())
assert.Equal(t, "", *accesskeyId)

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(200, `{"AccessKeyId":"accessKeyId","AccessKeySecret":"accessKeySecret","SecurityToken":"securitytoken","Expiration":"2018-01-02T15:04:05Z","Code":"Success"}`, nil)
}
}
accesskeyId, err = auth.GetAccessKeyId()
assert.Nil(t, err)
assert.Equal(t, "accessKeyId", *accesskeyId)

accesskeySecret, err := auth.GetAccessKeySecret()
assert.Nil(t, err)
assert.Equal(t, "accessKeySecret", *accesskeySecret)

ststoken, err := auth.GetSecurityToken()
assert.Nil(t, err)
assert.Equal(t, "securitytoken", *ststoken)

err = errors.New("credentials")
err = hookParse(err)
assert.Equal(t, "credentials", err.Error())

cred, err := auth.GetCredential()
assert.Nil(t, err)
assert.Equal(t, "accessKeyId", *cred.AccessKeyId)
assert.Equal(t, "accessKeySecret", *cred.AccessKeySecret)
assert.Equal(t, "securitytoken", *cred.SecurityToken)
assert.Nil(t, cred.BearerToken)
assert.Equal(t, "ecs_ram_role", *cred.Type)

originHookParse := hookParse
hookParse = func(err error) error {
return errors.New("error parse")
}
defer func() {
hookParse = originHookParse
}()
accesskeyId, err = auth.GetAccessKeyId()
assert.Equal(t, "refresh Ecs sts token err: error parse", err.Error())
assert.Equal(t, "", *accesskeyId)
}
7 changes: 5 additions & 2 deletions credentials/instance_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package credentials

import (
"os"
"strings"

"github.com/alibabacloud-go/tea/tea"
)
Expand All @@ -19,10 +20,12 @@ func (p *instanceCredentialsProvider) resolve() (*Config, error) {
if !ok {
return nil, nil
}
enableIMDSv2, _ := os.LookupEnv(ENVEcsMetadataIMDSv2Enable)

config := &Config{
Type: tea.String("ecs_ram_role"),
RoleName: tea.String(roleName),
Type: tea.String("ecs_ram_role"),
RoleName: tea.String(roleName),
EnableIMDSv2: tea.Bool(strings.ToLower(enableIMDSv2) == "true"),
}
return config, nil
}
Loading
Loading