Skip to content

Commit

Permalink
feat: add imds and external uri in default chain && resolve credentia…
Browse files Browse the repository at this point in the history
…ls timeout
  • Loading branch information
yndu13 committed Nov 5, 2024
1 parent 25ec51c commit 571bb67
Show file tree
Hide file tree
Showing 13 changed files with 583 additions and 68 deletions.
72 changes: 45 additions & 27 deletions credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,32 +39,50 @@ type Credential interface {

// Config is important when call NewCredential
type Config struct {
Type *string `json:"type"`
AccessKeyId *string `json:"access_key_id"`
AccessKeySecret *string `json:"access_key_secret"`
OIDCProviderArn *string `json:"oidc_provider_arn"`
OIDCTokenFilePath *string `json:"oidc_token"`
RoleArn *string `json:"role_arn"`
RoleSessionName *string `json:"role_session_name"`
PublicKeyId *string `json:"public_key_id"`
RoleName *string `json:"role_name"`
EnableIMDSv2 *bool `json:"enable_imds_v2"`
DisableIMDSv1 *bool `json:"disable_imds_v1"`
MetadataTokenDuration *int `json:"metadata_token_duration"`
SessionExpiration *int `json:"session_expiration"`
PrivateKeyFile *string `json:"private_key_file"`
BearerToken *string `json:"bearer_token"`
SecurityToken *string `json:"security_token"`
RoleSessionExpiration *int `json:"role_session_expiration"`
Policy *string `json:"policy"`
Host *string `json:"host"`
Timeout *int `json:"timeout"`
ConnectTimeout *int `json:"connect_timeout"`
Proxy *string `json:"proxy"`
InAdvanceScale *float64 `json:"inAdvanceScale"`
Url *string `json:"url"`
STSEndpoint *string `json:"sts_endpoint"`
ExternalId *string `json:"external_id"`
// Credential type, including access_key, sts, bearer, ecs_ram_role, ram_role_arn, rsa_key_pair, oidc_role_arn, credentials_uri
Type *string `json:"type"`
AccessKeyId *string `json:"access_key_id"`
AccessKeySecret *string `json:"access_key_secret"`
SecurityToken *string `json:"security_token"`
BearerToken *string `json:"bearer_token"`

// Used when the type is ram_role_arn or oidc_role_arn
OIDCProviderArn *string `json:"oidc_provider_arn"`
OIDCTokenFilePath *string `json:"oidc_token"`
RoleArn *string `json:"role_arn"`
RoleSessionName *string `json:"role_session_name"`
RoleSessionExpiration *int `json:"role_session_expiration"`
Policy *string `json:"policy"`
ExternalId *string `json:"external_id"`
STSEndpoint *string `json:"sts_endpoint"`

// Used when the type is ecs_ram_role
RoleName *string `json:"role_name"`
// Deprecated
EnableIMDSv2 *bool `json:"enable_imds_v2"`
DisableIMDSv1 *bool `json:"disable_imds_v1"`
// Deprecated
MetadataTokenDuration *int `json:"metadata_token_duration"`

// Used when the type is credentials_uri
Url *string `json:"url"`

// Deprecated
// Used when the type is rsa_key_pair
SessionExpiration *int `json:"session_expiration"`
PublicKeyId *string `json:"public_key_id"`
PrivateKeyFile *string `json:"private_key_file"`
Host *string `json:"host"`

// Read timeout, in milliseconds.
// The default value for ecs_ram_role is 1000ms, the default value for ram_role_arn is 5000ms, and the default value for oidc_role_arn is 5000ms.
Timeout *int `json:"timeout"`
// Connection timeout, in milliseconds.
// The default value for ecs_ram_role is 1000ms, the default value for ram_role_arn is 10000ms, and the default value for oidc_role_arn is 10000ms.
ConnectTimeout *int `json:"connect_timeout"`

Proxy *string `json:"proxy"`
InAdvanceScale *float64 `json:"inAdvanceScale"`
}

func (s Config) String() string {
Expand Down Expand Up @@ -343,7 +361,7 @@ func NewCredential(config *Config) (credential Credential, err error) {
}
credential = newBearerTokenCredential(tea.StringValue(config.BearerToken))
default:
err = errors.New("invalid type option, support: access_key, sts, ecs_ram_role, ram_role_arn, rsa_key_pair")
err = errors.New("invalid type option, support: access_key, sts, bearer, ecs_ram_role, ram_role_arn, rsa_key_pair, oidc_role_arn, credentials_uri")
return
}
return credential, nil
Expand Down
6 changes: 3 additions & 3 deletions credentials/credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ this is privatekey`

func TestConfig(t *testing.T) {
config := new(Config)
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"disable_imds_v1\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiration\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.String())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"disable_imds_v1\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiration\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.GoString())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"security_token\": null,\n \"bearer_token\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"role_session_expiration\": null,\n \"policy\": null,\n \"external_id\": null,\n \"sts_endpoint\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"disable_imds_v1\": null,\n \"metadata_token_duration\": null,\n \"url\": null,\n \"session_expiration\": null,\n \"public_key_id\": null,\n \"private_key_file\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null\n}", config.String())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"security_token\": null,\n \"bearer_token\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"role_session_expiration\": null,\n \"policy\": null,\n \"external_id\": null,\n \"sts_endpoint\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"disable_imds_v1\": null,\n \"metadata_token_duration\": null,\n \"url\": null,\n \"session_expiration\": null,\n \"public_key_id\": null,\n \"private_key_file\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null\n}", config.GoString())

config.SetSTSEndpoint("sts.cn-hangzhou.aliyuncs.com")
assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", *config.STSEndpoint)
Expand Down Expand Up @@ -309,7 +309,7 @@ func TestNewCredentialWithInvalidType(t *testing.T) {
config.SetType("sdk")
cred, err := NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "invalid type option, support: access_key, sts, ecs_ram_role, ram_role_arn, rsa_key_pair", err.Error())
assert.Equal(t, "invalid type option, support: access_key, sts, bearer, ecs_ram_role, ram_role_arn, rsa_key_pair, oidc_role_arn, credentials_uri", err.Error())
assert.Nil(t, cred)
}

Expand Down
7 changes: 7 additions & 0 deletions credentials/internal/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

type Request struct {
Method string // http request method
URL string // http url
Protocol string // http or https
Host string // http host
ReadTimeout time.Duration
Expand All @@ -31,6 +32,9 @@ type Request struct {

func (req *Request) BuildRequestURL() string {
httpUrl := fmt.Sprintf("%s://%s%s", req.Protocol, req.Host, req.Path)
if req.URL != "" {
httpUrl = req.URL
}

querystring := utils.GetURLFormedMap(req.Queries)
if querystring != "" {
Expand Down Expand Up @@ -60,6 +64,9 @@ func Do(req *Request) (res *Response, err error) {
querystring := utils.GetURLFormedMap(req.Queries)
// do request
httpUrl := fmt.Sprintf("%s://%s%s?%s", req.Protocol, req.Host, req.Path, querystring)
if req.URL != "" {
httpUrl = req.URL
}

var body io.Reader
if req.Method == "GET" {
Expand Down
14 changes: 9 additions & 5 deletions credentials/providers/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,19 @@ func NewDefaultCredentialsProvider() (provider *DefaultCredentialsProvider) {
}

// Add IMDS
if os.Getenv("ALIBABA_CLOUD_ECS_METADATA") != "" {
ecsRamRoleProvider, err := NewECSRAMRoleCredentialsProviderBuilder().WithRoleName(os.Getenv("ALIBABA_CLOUD_ECS_METADATA")).Build()
ecsRamRoleProvider, err := NewECSRAMRoleCredentialsProviderBuilder().Build()
if err == nil {
providers = append(providers, ecsRamRoleProvider)
}

// credentials uri
if os.Getenv("ALIBABA_CLOUD_CREDENTIALS_URI") != "" {
credentialsUriProvider, err := NewURLCredentialsProviderBuilderBuilder().Build()
if err == nil {
providers = append(providers, ecsRamRoleProvider)
providers = append(providers, credentialsUriProvider)
}
}

// TODO: ALIBABA_CLOUD_CREDENTIALS_URI check

return &DefaultCredentialsProvider{
providerChain: providers,
}
Expand Down
42 changes: 37 additions & 5 deletions credentials/providers/default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
func TestDefaultCredentialsProvider(t *testing.T) {
provider := NewDefaultCredentialsProvider()
assert.NotNil(t, provider)
assert.Len(t, provider.providerChain, 3)
assert.Len(t, provider.providerChain, 4)
_, ok := provider.providerChain[0].(*EnvironmentVariableCredentialsProvider)
assert.True(t, ok)

Expand All @@ -21,11 +21,15 @@ func TestDefaultCredentialsProvider(t *testing.T) {
_, ok = provider.providerChain[2].(*ProfileCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[3].(*ECSRAMRoleCredentialsProvider)
assert.True(t, ok)

// Add oidc provider
rollback := utils.Memory("ALIBABA_CLOUD_OIDC_TOKEN_FILE",
"ALIBABA_CLOUD_OIDC_PROVIDER_ARN",
"ALIBABA_CLOUD_ROLE_ARN",
"ALIBABA_CLOUD_ECS_METADATA")
"ALIBABA_CLOUD_ECS_METADATA",
"ALIBABA_CLOUD_CREDENTIALS_URI")

defer rollback()
os.Setenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "/path/to/oidc.token")
Expand All @@ -34,7 +38,7 @@ func TestDefaultCredentialsProvider(t *testing.T) {

provider = NewDefaultCredentialsProvider()
assert.NotNil(t, provider)
assert.Len(t, provider.providerChain, 4)
assert.Len(t, provider.providerChain, 5)
_, ok = provider.providerChain[0].(*EnvironmentVariableCredentialsProvider)
assert.True(t, ok)

Expand All @@ -47,7 +51,10 @@ func TestDefaultCredentialsProvider(t *testing.T) {
_, ok = provider.providerChain[3].(*ProfileCredentialsProvider)
assert.True(t, ok)

// Add ecs ram role
_, ok = provider.providerChain[4].(*ECSRAMRoleCredentialsProvider)
assert.True(t, ok)

// Add ecs ram role name
os.Setenv("ALIBABA_CLOUD_ECS_METADATA", "rolename")
provider = NewDefaultCredentialsProvider()
assert.NotNil(t, provider)
Expand All @@ -66,12 +73,36 @@ func TestDefaultCredentialsProvider(t *testing.T) {

_, ok = provider.providerChain[4].(*ECSRAMRoleCredentialsProvider)
assert.True(t, ok)

// Add ecs ram role
os.Setenv("ALIBABA_CLOUD_CREDENTIALS_URI", "http://")
provider = NewDefaultCredentialsProvider()
assert.NotNil(t, provider)
assert.Len(t, provider.providerChain, 6)
_, ok = provider.providerChain[0].(*EnvironmentVariableCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[1].(*OIDCCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[2].(*CLIProfileCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[3].(*ProfileCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[4].(*ECSRAMRoleCredentialsProvider)
assert.True(t, ok)

_, ok = provider.providerChain[5].(*URLCredentialsProvider)
assert.True(t, ok)
}

func TestDefaultCredentialsProvider_GetCredentials(t *testing.T) {
rollback := utils.Memory("ALIBABA_CLOUD_ACCESS_KEY_ID",
"ALIBABA_CLOUD_ACCESS_KEY_SECRET",
"ALIBABA_CLOUD_SECURITY_TOKEN")
"ALIBABA_CLOUD_SECURITY_TOKEN",
"ALIBABA_CLOUD_ECS_METADATA_DISABLED")

defer func() {
getHomePath = utils.GetHomePath
Expand All @@ -83,6 +114,7 @@ func TestDefaultCredentialsProvider_GetCredentials(t *testing.T) {
return ""
}

os.Setenv("ALIBABA_CLOUD_ECS_METADATA_DISABLED", "true")
provider := NewDefaultCredentialsProvider()
assert.Len(t, provider.providerChain, 3)
_, err := provider.GetCredentials()
Expand Down
85 changes: 69 additions & 16 deletions credentials/providers/ecs_ram_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package providers

import (
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
Expand All @@ -17,6 +18,8 @@ type ECSRAMRoleCredentialsProvider struct {
// for sts
session *sessionCredentials
expirationTimestamp int64
// for http options
httpOptions *HttpOptions
}

type ECSRAMRoleCredentialsProviderBuilder struct {
Expand All @@ -39,10 +42,20 @@ func (builder *ECSRAMRoleCredentialsProviderBuilder) WithDisableIMDSv1(disableIM
return builder
}

func (builder *ECSRAMRoleCredentialsProviderBuilder) WithHttpOptions(httpOptions *HttpOptions) *ECSRAMRoleCredentialsProviderBuilder {
builder.provider.httpOptions = httpOptions
return builder
}

const defaultMetadataTokenDuration = 21600 // 6 hours

func (builder *ECSRAMRoleCredentialsProviderBuilder) Build() (provider *ECSRAMRoleCredentialsProvider, err error) {

if strings.ToLower(os.Getenv("ALIBABA_CLOUD_ECS_METADATA_DISABLED")) == "true" {
err = errors.New("IMDS credentials is disabled")
return
}

// 设置 roleName 默认值
if builder.provider.roleName == "" {
builder.provider.roleName = os.Getenv("ALIBABA_CLOUD_ECS_METADATA")
Expand Down Expand Up @@ -75,14 +88,27 @@ func (provider *ECSRAMRoleCredentialsProvider) needUpdateCredential() bool {

func (provider *ECSRAMRoleCredentialsProvider) getRoleName() (roleName string, err error) {
req := &httputil.Request{
Method: "GET",
Protocol: "http",
Host: "100.100.100.200",
Path: "/latest/meta-data/ram/security-credentials/",
ConnectTimeout: 5 * time.Second,
ReadTimeout: 5 * time.Second,
Headers: map[string]string{},
Method: "GET",
Protocol: "http",
Host: "100.100.100.200",
Path: "/latest/meta-data/ram/security-credentials/",
Headers: map[string]string{},
}

connectTimeout := 1 * time.Second
readTimeout := 1 * time.Second

if provider.httpOptions != nil && provider.httpOptions.ConnectTimeout > 0 {
connectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Millisecond
}
if provider.httpOptions != nil && provider.httpOptions.ReadTimeout > 0 {
readTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Millisecond
}
if provider.httpOptions != nil && provider.httpOptions.Proxy != "" {
req.Proxy = provider.httpOptions.Proxy
}
req.ConnectTimeout = connectTimeout
req.ReadTimeout = readTimeout

metadataToken, err := provider.getMetadataToken()
if err != nil {
Expand Down Expand Up @@ -117,14 +143,27 @@ func (provider *ECSRAMRoleCredentialsProvider) getCredentials() (session *sessio
}

req := &httputil.Request{
Method: "GET",
Protocol: "http",
Host: "100.100.100.200",
Path: "/latest/meta-data/ram/security-credentials/" + roleName,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 5 * time.Second,
Headers: map[string]string{},
Method: "GET",
Protocol: "http",
Host: "100.100.100.200",
Path: "/latest/meta-data/ram/security-credentials/" + roleName,
Headers: map[string]string{},
}

connectTimeout := 1 * time.Second
readTimeout := 1 * time.Second

if provider.httpOptions != nil && provider.httpOptions.ConnectTimeout > 0 {
connectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Millisecond
}
if provider.httpOptions != nil && provider.httpOptions.ReadTimeout > 0 {
readTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Millisecond
}
if provider.httpOptions != nil && provider.httpOptions.Proxy != "" {
req.Proxy = provider.httpOptions.Proxy
}
req.ConnectTimeout = connectTimeout
req.ReadTimeout = readTimeout

metadataToken, err := provider.getMetadataToken()
if err != nil {
Expand Down Expand Up @@ -209,9 +248,23 @@ func (provider *ECSRAMRoleCredentialsProvider) getMetadataToken() (metadataToken
Headers: map[string]string{
"X-aliyun-ecs-metadata-token-ttl-seconds": strconv.Itoa(defaultMetadataTokenDuration),
},
ConnectTimeout: 5 * time.Second,
ReadTimeout: 5 * time.Second,
}

connectTimeout := 1 * time.Second
readTimeout := 1 * time.Second

if provider.httpOptions != nil && provider.httpOptions.ConnectTimeout > 0 {
connectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Millisecond
}
if provider.httpOptions != nil && provider.httpOptions.ReadTimeout > 0 {
readTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Millisecond
}
if provider.httpOptions != nil && provider.httpOptions.Proxy != "" {
req.Proxy = provider.httpOptions.Proxy
}
req.ConnectTimeout = connectTimeout
req.ReadTimeout = readTimeout

res, _err := httpDo(req)
if _err != nil {
if provider.disableIMDSv1 {
Expand Down
Loading

0 comments on commit 571bb67

Please sign in to comment.