From 412f4ed69e17d2e2a2bef8e5ed258843c6a0fe61 Mon Sep 17 00:00:00 2001 From: Jackson Tian Date: Tue, 20 Aug 2024 00:03:01 +0800 Subject: [PATCH] refine the ecs ram role credentials provider --- credentials/credential.go | 20 +- .../internal/providers/ecs_ram_role.go | 253 +++++++++ .../internal/providers/ecs_ram_role_test.go | 484 ++++++++++++++++++ 3 files changed, 747 insertions(+), 10 deletions(-) create mode 100644 credentials/internal/providers/ecs_ram_role.go create mode 100644 credentials/internal/providers/ecs_ram_role_test.go diff --git a/credentials/credential.go b/credentials/credential.go index b374289..70e6c60 100644 --- a/credentials/credential.go +++ b/credentials/credential.go @@ -253,17 +253,17 @@ func NewCredential(config *Config) (credential Credential, err error) { credential = fromCredentialsProvider("sts", provider) case "ecs_ram_role": - runtime := &utils.Runtime{ - Host: tea.StringValue(config.Host), - ReadTimeout: tea.IntValue(config.Timeout), - ConnectTimeout: tea.IntValue(config.ConnectTimeout), + provider, err := providers.NewECSRAMRoleCredentialsProviderBuilder(). + WithRoleName(tea.StringValue(config.RoleName)). + WithEnableIMDSv2(tea.BoolValue(config.EnableIMDSv2)). + WithMetadataTokenDurationSeconds(tea.IntValue(config.MetadataTokenDuration)). + Build() + + if err != nil { + return nil, err } - credential = newEcsRAMRoleCredentialWithEnableIMDSv2( - tea.StringValue(config.RoleName), - tea.BoolValue(config.EnableIMDSv2), - tea.IntValue(config.MetadataTokenDuration), - tea.Float64Value(config.InAdvanceScale), - runtime) + + credential = fromCredentialsProvider("ecs_ram_role", provider) case "ram_role_arn": var credentialsProvider providers.CredentialsProvider if config.SecurityToken != nil { diff --git a/credentials/internal/providers/ecs_ram_role.go b/credentials/internal/providers/ecs_ram_role.go new file mode 100644 index 0000000..25fd105 --- /dev/null +++ b/credentials/internal/providers/ecs_ram_role.go @@ -0,0 +1,253 @@ +package providers + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "os" + "strconv" + "strings" + "time" + + "github.com/aliyun/credentials-go/credentials/internal/utils" + "github.com/aliyun/credentials-go/credentials/request" +) + +const securityCredURL = "http://100.100.100.200/latest/meta-data/ram/security-credentials/" +const apiTokenURL = "http://100.100.100.200/latest/api/token" + +type ECSRAMRoleCredentialsProvider struct { + roleName string + metadataTokenDurationSeconds int + enableIMDSv2 bool + runtime *utils.Runtime + // for sts + session *sessionCredentials + expirationTimestamp int64 +} + +type ECSRAMRoleCredentialsProviderBuilder struct { + provider *ECSRAMRoleCredentialsProvider +} + +func NewECSRAMRoleCredentialsProviderBuilder() *ECSRAMRoleCredentialsProviderBuilder { + return &ECSRAMRoleCredentialsProviderBuilder{ + provider: &ECSRAMRoleCredentialsProvider{ + // TBD: 默认启用 IMDS v2 + // enableIMDSv2: os.Getenv("ALIBABA_CLOUD_IMDSV2_DISABLED") != "true", // 默认启用 v2 + }, + } +} + +func (builder *ECSRAMRoleCredentialsProviderBuilder) WithMetadataTokenDurationSeconds(metadataTokenDurationSeconds int) *ECSRAMRoleCredentialsProviderBuilder { + builder.provider.metadataTokenDurationSeconds = metadataTokenDurationSeconds + return builder +} + +func (builder *ECSRAMRoleCredentialsProviderBuilder) WithRoleName(roleName string) *ECSRAMRoleCredentialsProviderBuilder { + builder.provider.roleName = roleName + return builder +} + +func (builder *ECSRAMRoleCredentialsProviderBuilder) WithEnableIMDSv2(enableIMDSv2 bool) *ECSRAMRoleCredentialsProviderBuilder { + builder.provider.enableIMDSv2 = enableIMDSv2 + return builder +} + +const defaultMetadataTokenDuration = 21600 // 6 hours + +func (builder *ECSRAMRoleCredentialsProviderBuilder) Build() (provider *ECSRAMRoleCredentialsProvider, err error) { + // 设置 roleName 默认值 + if builder.provider.roleName == "" { + builder.provider.roleName = os.Getenv("ALIBABA_CLOUD_ECS_METADATA") + } + + if builder.provider.metadataTokenDurationSeconds == 0 { + builder.provider.metadataTokenDurationSeconds = defaultMetadataTokenDuration + } + + if builder.provider.metadataTokenDurationSeconds < 1 || builder.provider.metadataTokenDurationSeconds > 21600 { + err = errors.New("the metadata token duration seconds must be 1-21600") + return + } + + builder.provider.runtime = &utils.Runtime{ + ConnectTimeout: 5, + ReadTimeout: 5, + } + + provider = builder.provider + return +} + +type ecsRAMRoleResponse struct { + Code *string `json:"Code"` + AccessKeyId *string `json:"AccessKeyId"` + AccessKeySecret *string `json:"AccessKeySecret"` + SecurityToken *string `json:"SecurityToken"` + LastUpdated *string `json:"LastUpdated"` + Expiration *string `json:"Expiration"` +} + +func (provider *ECSRAMRoleCredentialsProvider) needUpdateCredential() bool { + if provider.expirationTimestamp == 0 { + return true + } + + return provider.expirationTimestamp-time.Now().Unix() <= 180 +} + +func (provider *ECSRAMRoleCredentialsProvider) getRoleName() (roleName string, err error) { + httpRequest, err := hookNewRequest(http.NewRequest)("GET", securityCredURL, strings.NewReader("")) + if err != nil { + err = fmt.Errorf("get role name failed: %s", err.Error()) + return + } + + if provider.enableIMDSv2 { + metadataToken, err := provider.getMetadataToken() + if err != nil { + return "", err + } + httpRequest.Header.Set("x-aliyun-ecs-metadata-token", metadataToken) + } + + httpClient := &http.Client{ + Timeout: 5 * time.Second, + } + httpResponse, err := hookDo(httpClient.Do)(httpRequest) + if err != nil { + err = fmt.Errorf("get role name failed: %s", err.Error()) + return + } + + if httpResponse.StatusCode != http.StatusOK { + err = fmt.Errorf("get role name failed: request %s %d", securityCredURL, httpResponse.StatusCode) + return + } + + defer httpResponse.Body.Close() + + responseBody, err := ioutil.ReadAll(httpResponse.Body) + if err != nil { + return + } + + roleName = strings.TrimSpace(string(responseBody)) + return +} + +func (provider *ECSRAMRoleCredentialsProvider) getCredentials() (session *sessionCredentials, err error) { + roleName := provider.roleName + if roleName == "" { + roleName, err = provider.getRoleName() + if err != nil { + return + } + } + + requestUrl := securityCredURL + roleName + httpRequest, err := hookNewRequest(http.NewRequest)("GET", requestUrl, strings.NewReader("")) + if err != nil { + err = fmt.Errorf("refresh Ecs sts token err: %s", err.Error()) + return + } + + if provider.enableIMDSv2 { + metadataToken, err := provider.getMetadataToken() + if err != nil { + return nil, err + } + httpRequest.Header.Set("x-aliyun-ecs-metadata-token", metadataToken) + } + + httpClient := &http.Client{ + Timeout: 5 * time.Second, + } + httpResponse, err := hookDo(httpClient.Do)(httpRequest) + if err != nil { + err = fmt.Errorf("refresh Ecs sts token err: %s", err.Error()) + return + } + + defer httpResponse.Body.Close() + + responseBody, err := ioutil.ReadAll(httpResponse.Body) + if err != nil { + return + } + + if httpResponse.StatusCode != http.StatusOK { + err = fmt.Errorf("refresh Ecs sts token err, httpStatus: %d, message = %s", httpResponse.StatusCode, string(responseBody)) + return + } + + var data ecsRAMRoleResponse + err = json.Unmarshal(responseBody, &data) + if err != nil { + err = fmt.Errorf("refresh Ecs sts token err, json.Unmarshal fail: %s", err.Error()) + return + } + + if data.AccessKeyId == nil || data.AccessKeySecret == nil || data.SecurityToken == nil { + err = fmt.Errorf("refresh Ecs sts token err, fail to get credentials") + return + } + + if *data.Code != "Success" { + err = fmt.Errorf("refresh Ecs sts token err, Code is not Success") + return + } + + session = &sessionCredentials{ + AccessKeyId: *data.AccessKeyId, + AccessKeySecret: *data.AccessKeySecret, + SecurityToken: *data.SecurityToken, + Expiration: *data.Expiration, + } + return +} + +func (provider *ECSRAMRoleCredentialsProvider) GetCredentials() (cc *Credentials, err error) { + if provider.session == nil || provider.needUpdateCredential() { + session, err1 := provider.getCredentials() + if err1 != nil { + return nil, err1 + } + + provider.session = session + expirationTime, err2 := time.Parse("2006-01-02T15:04:05Z", session.Expiration) + if err2 != nil { + return nil, err2 + } + provider.expirationTimestamp = expirationTime.Unix() + } + + cc = &Credentials{ + AccessKeyId: provider.session.AccessKeyId, + AccessKeySecret: provider.session.AccessKeySecret, + SecurityToken: provider.session.SecurityToken, + ProviderName: provider.GetProviderName(), + } + return +} + +func (provider *ECSRAMRoleCredentialsProvider) GetProviderName() string { + return "ecs_ram_role" +} + +func (provider *ECSRAMRoleCredentialsProvider) getMetadataToken() (metadataToken string, err error) { + request := request.NewCommonRequest() + request.URL = apiTokenURL + request.Method = "PUT" + request.Headers["X-aliyun-ecs-metadata-token-ttl-seconds"] = strconv.Itoa(provider.metadataTokenDurationSeconds) + content, err := doAction(request, provider.runtime) + if err != nil { + err = fmt.Errorf("get metadata token failed: %s", err.Error()) + return + } + metadataToken = string(content) + return +} diff --git a/credentials/internal/providers/ecs_ram_role_test.go b/credentials/internal/providers/ecs_ram_role_test.go new file mode 100644 index 0000000..36f03bd --- /dev/null +++ b/credentials/internal/providers/ecs_ram_role_test.go @@ -0,0 +1,484 @@ +package providers + +import ( + "errors" + "io" + "io/ioutil" + "net/http" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewECSRAMRoleCredentialsProvider(t *testing.T) { + p, err := NewECSRAMRoleCredentialsProviderBuilder().Build() + assert.Nil(t, err) + assert.Equal(t, "", p.roleName) + assert.Equal(t, 21600, p.metadataTokenDurationSeconds) + + _, err = NewECSRAMRoleCredentialsProviderBuilder().WithMetadataTokenDurationSeconds(1000000000).Build() + assert.EqualError(t, err, "the metadata token duration seconds must be 1-21600") + + p, err = NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("role").WithMetadataTokenDurationSeconds(3600).Build() + assert.Nil(t, err) + assert.Equal(t, "role", p.roleName) + assert.Equal(t, 3600, p.metadataTokenDurationSeconds) + + assert.True(t, p.needUpdateCredential()) +} + +func TestECSRAMRoleCredentialsProvider_getRoleName(t *testing.T) { + p, err := NewECSRAMRoleCredentialsProviderBuilder().Build() + assert.Nil(t, err) + + originNewRequest := hookNewRequest + defer func() { hookNewRequest = originNewRequest }() + + // case 1: mock new http request failed + hookNewRequest = func(fn newReuqest) newReuqest { + return func(method, url string, body io.Reader) (*http.Request, error) { + return nil, errors.New("new http request failed") + } + } + _, err = p.getRoleName() + assert.NotNil(t, err) + assert.Equal(t, "get role name failed: new http request failed", err.Error()) + // reset new request + hookNewRequest = originNewRequest + + originDo := hookDo + defer func() { hookDo = originDo }() + + // case 2: server error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + err = errors.New("mock server error") + return + } + } + _, err = p.getRoleName() + assert.NotNil(t, err) + assert.Equal(t, "get role name failed: mock server error", err.Error()) + + // case 3: 4xx error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(400, "4xx error") + return + } + } + + _, err = p.getRoleName() + assert.NotNil(t, err) + assert.Equal(t, "get role name failed: request http://100.100.100.200/latest/meta-data/ram/security-credentials/ 400", err.Error()) + + // case 4: mock read response error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + status := strconv.Itoa(200) + res = &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + Header: map[string][]string{}, + StatusCode: 200, + Status: status + " " + http.StatusText(200), + } + res.Body = ioutil.NopCloser(&errorReader{}) + return + } + } + _, err = p.getRoleName() + assert.NotNil(t, err) + assert.Equal(t, "read failed", err.Error()) + + // case 5: value json + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(200, "rolename") + return + } + } + roleName, err := p.getRoleName() + assert.Nil(t, err) + assert.Equal(t, "rolename", roleName) +} + +func TestECSRAMRoleCredentialsProvider_getRoleNameWithMetadataV2(t *testing.T) { + p, err := NewECSRAMRoleCredentialsProviderBuilder().WithEnableIMDSv2(true).Build() + assert.Nil(t, err) + + // case 1: get metadata token failed + originDo := hookDo + defer func() { hookDo = originDo }() + + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + err = errors.New("mock server error") + return + } + } + _, err = p.getRoleName() + assert.NotNil(t, err) + assert.Equal(t, "get metadata token failed: mock server error", err.Error()) + + // case 2: return token + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + if req.URL.Path == "/latest/api/token" { + res = mockResponse(200, `tokenxxxxx`) + } else { + assert.Equal(t, "tokenxxxxx", req.Header.Get("x-aliyun-ecs-metadata-token")) + res = mockResponse(200, "rolename") + } + return + } + } + + roleName, err := p.getRoleName() + assert.Nil(t, err) + assert.Equal(t, "rolename", roleName) +} + +func TestECSRAMRoleCredentialsProvider_getCredentials(t *testing.T) { + originDo := hookDo + defer func() { hookDo = originDo }() + + p, err := NewECSRAMRoleCredentialsProviderBuilder().Build() + assert.Nil(t, err) + + // case 1: server error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + err = errors.New("mock server error") + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "get role name failed: mock server error", err.Error()) + + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { + res = mockResponse(200, "rolename") + return + } + + err = errors.New("mock server error") + return + } + } + + originNewRequest := hookNewRequest + defer func() { hookNewRequest = originNewRequest }() + + // case 2: mock new http request failed + hookNewRequest = func(fn newReuqest) newReuqest { + return func(method, url string, body io.Reader) (*http.Request, error) { + if url == "http://100.100.100.200/latest/meta-data/ram/security-credentials/rolename" { + return nil, errors.New("new http request failed") + } + return http.NewRequest(method, url, body) + } + } + + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "refresh Ecs sts token err: new http request failed", err.Error()) + + hookNewRequest = originNewRequest + + // case 3 + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { + res = mockResponse(200, "rolename") + return + } + + if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { + err = errors.New("mock server error") + return + } + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "refresh Ecs sts token err: mock server error", err.Error()) + + // case 4: mock read response error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { + res = mockResponse(200, "rolename") + return + } + + if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { + status := strconv.Itoa(200) + res = &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + Header: map[string][]string{}, + StatusCode: 200, + Status: status + " " + http.StatusText(200), + } + res.Body = ioutil.NopCloser(&errorReader{}) + return + } + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "read failed", err.Error()) + + // case 4: 4xx error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { + res = mockResponse(200, "rolename") + return + } + + if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { + res = mockResponse(400, "4xx error") + return + } + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "refresh Ecs sts token err, httpStatus: 400, message = 4xx error", err.Error()) + + // case 5: invalid json + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { + res = mockResponse(200, "rolename") + return + } + + if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { + res = mockResponse(200, "invalid json") + return + } + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "refresh Ecs sts token err, json.Unmarshal fail: invalid character 'i' looking for beginning of value", err.Error()) + + // case 6: empty response json + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { + res = mockResponse(200, "rolename") + return + } + + if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { + res = mockResponse(200, "null") + return + } + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "refresh Ecs sts token err, fail to get credentials", err.Error()) + + // case 7: empty session ak response json + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { + res = mockResponse(200, "rolename") + return + } + + if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { + res = mockResponse(200, `{}`) + return + } + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "refresh Ecs sts token err, fail to get credentials", err.Error()) + + // case 8: non-success response + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { + res = mockResponse(200, "rolename") + return + } + + if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { + res = mockResponse(200, `{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Failed"}`) + return + } + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "refresh Ecs sts token err, Code is not Success", err.Error()) + + // case 8: mock ok value + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { + res = mockResponse(200, "rolename") + return + } + + if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { + res = mockResponse(200, `{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Success"}`) + return + } + return + } + } + creds, err := p.getCredentials() + assert.Nil(t, err) + assert.Equal(t, "saki", creds.AccessKeyId) + assert.Equal(t, "saks", creds.AccessKeySecret) + assert.Equal(t, "token", creds.SecurityToken) + assert.Equal(t, "2021-10-20T04:27:09Z", creds.Expiration) + + // needUpdateCredential + assert.True(t, p.needUpdateCredential()) + p.expirationTimestamp = time.Now().Unix() + assert.True(t, p.needUpdateCredential()) + + p.expirationTimestamp = time.Now().Unix() + 300 + assert.False(t, p.needUpdateCredential()) +} + +func TestECSRAMRoleCredentialsProvider_getCredentialsWithMetadataV2(t *testing.T) { + originDo := hookDo + defer func() { hookDo = originDo }() + + p, err := NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("rolename").WithEnableIMDSv2(true).Build() + assert.Nil(t, err) + + // case 1: get metadata token failed + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + err = errors.New("mock server error") + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "get metadata token failed: mock server error", err.Error()) + + // case 2: return token + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + if req.URL.Path == "/latest/api/token" { + res = mockResponse(200, `tokenxxxxx`) + return + } + if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { + assert.Equal(t, "tokenxxxxx", req.Header.Get("x-aliyun-ecs-metadata-token")) + res = mockResponse(200, `{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Success"}`) + } + return + } + } + + creds, err := p.getCredentials() + assert.Nil(t, err) + assert.Equal(t, "saki", creds.AccessKeyId) + assert.Equal(t, "saks", creds.AccessKeySecret) + assert.Equal(t, "token", creds.SecurityToken) + assert.Equal(t, "2021-10-20T04:27:09Z", creds.Expiration) + + // needUpdateCredential + assert.True(t, p.needUpdateCredential()) + p.expirationTimestamp = time.Now().Unix() + assert.True(t, p.needUpdateCredential()) + + p.expirationTimestamp = time.Now().Unix() + 300 + assert.False(t, p.needUpdateCredential()) +} + +func TestECSRAMRoleCredentialsProviderGetCredentials(t *testing.T) { + originDo := hookDo + defer func() { hookDo = originDo }() + + p, err := NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("rolename").Build() + assert.Nil(t, err) + // case 1: get credentials failed + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + err = errors.New("mock server error") + return + } + } + _, err = p.GetCredentials() + assert.NotNil(t, err) + assert.Equal(t, "refresh Ecs sts token err: mock server error", err.Error()) + + // case 2: get invalid expiration + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(200, `{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"invalidexpiration","SecurityToken":"token","Code":"Success"}`) + return + } + } + _, err = p.GetCredentials() + assert.NotNil(t, err) + assert.Equal(t, "parsing time \"invalidexpiration\" as \"2006-01-02T15:04:05Z\": cannot parse \"invalidexpiration\" as \"2006\"", err.Error()) + + // case 3: happy result + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(200, `{"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Success"}`) + return + } + } + cc, err := p.GetCredentials() + assert.Nil(t, err) + assert.Equal(t, "akid", cc.AccessKeyId) + assert.Equal(t, "aksecret", cc.AccessKeySecret) + assert.Equal(t, "token", cc.SecurityToken) + assert.True(t, p.needUpdateCredential()) +} + +func TestECSRAMRoleCredentialsProvider_getMetadataToken(t *testing.T) { + originDo := hookDo + defer func() { hookDo = originDo }() + + p, err := NewECSRAMRoleCredentialsProviderBuilder().Build() + assert.Nil(t, err) + + // case 1: server error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + err = errors.New("mock server error") + return + } + } + _, err = p.getMetadataToken() + assert.NotNil(t, err) + assert.Equal(t, "get metadata token failed: mock server error", err.Error()) + // case 2: return token + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(200, `tokenxxxxx`) + return + } + } + metadataToken, err := p.getMetadataToken() + assert.Nil(t, err) + assert.Equal(t, "tokenxxxxx", metadataToken) +}