Skip to content

Commit

Permalink
improve ram role arn credentials provider
Browse files Browse the repository at this point in the history
  • Loading branch information
JacksonTian authored and yndu13 committed Aug 20, 2024
1 parent 09e2c68 commit a4f7a0f
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 64 deletions.
76 changes: 34 additions & 42 deletions credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,27 +265,43 @@ func NewCredential(config *Config) (credential Credential, err error) {
tea.Float64Value(config.InAdvanceScale),
runtime)
case "ram_role_arn":
err = checkRAMRoleArn(config)
var credentialsProvider providers.CredentialsProvider
if config.SecurityToken != nil {
credentialsProvider, err = providers.NewStaticSTSCredentialsProviderBuilder().
WithAccessKeyId(tea.StringValue(config.AccessKeyId)).
WithAccessKeySecret(tea.StringValue(config.AccessKeySecret)).
WithSecurityToken(tea.StringValue(config.SecurityToken)).
Build()
} else {
credentialsProvider, err = providers.NewStaticAKCredentialsProviderBuilder().
WithAccessKeyId(tea.StringValue(config.AccessKeyId)).
WithAccessKeySecret(tea.StringValue(config.AccessKeySecret)).
Build()
}

if err != nil {
return
return nil, err
}
runtime := &utils.Runtime{
Host: tea.StringValue(config.Host),
Proxy: tea.StringValue(config.Proxy),
ReadTimeout: tea.IntValue(config.Timeout),
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
STSEndpoint: tea.StringValue(config.STSEndpoint),

provider, err := providers.NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(credentialsProvider).
WithRoleArn(tea.StringValue(config.RoleArn)).
WithRoleSessionName(tea.StringValue(config.RoleSessionName)).
WithPolicy(tea.StringValue(config.Policy)).
WithDurationSeconds(tea.IntValue(config.RoleSessionExpiration)).
WithExternalId(tea.StringValue(config.ExternalId)).
WithStsEndpoint(tea.StringValue(config.STSEndpoint)).
WithHttpOptions(&providers.HttpOptions{
Proxy: tea.StringValue(config.Proxy),
ReadTimeout: tea.IntValue(config.Timeout),
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
}).
Build()
if err != nil {
return nil, err
}
credential = newRAMRoleArnl(
tea.StringValue(config.AccessKeyId),
tea.StringValue(config.AccessKeySecret),
tea.StringValue(config.SecurityToken),
tea.StringValue(config.RoleArn),
tea.StringValue(config.RoleSessionName),
tea.StringValue(config.Policy),
tea.IntValue(config.RoleSessionExpiration),
tea.StringValue(config.ExternalId),
runtime)

credential = fromCredentialsProvider("ram_role_arn", provider)
case "rsa_key_pair":
err = checkRSAKeyPair(config)
if err != nil {
Expand Down Expand Up @@ -354,30 +370,6 @@ func checkoutAssumeRamoidc(config *Config) (err error) {
return
}

func checkRAMRoleArn(config *Config) (err error) {
if tea.StringValue(config.AccessKeyId) == "" {
err = errors.New("AccessKeyId cannot be empty")
return
}

if tea.StringValue(config.AccessKeySecret) == "" {
err = errors.New("AccessKeySecret cannot be empty")
return
}

if tea.StringValue(config.RoleArn) == "" {
err = errors.New("RoleArn cannot be empty")
return
}

if tea.StringValue(config.RoleSessionName) == "" {
err = errors.New("RoleSessionName cannot be empty")
return
}

return
}

func doAction(request *request.CommonRequest, runtime *utils.Runtime) (content []byte, err error) {
var urlEncoded string
if request.BodyParams != nil {
Expand Down
12 changes: 3 additions & 9 deletions credentials/credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,30 +171,24 @@ func TestNewCredentialWithRAMRoleARN(t *testing.T) {
config.SetAccessKeyId("")
cred, err := NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "AccessKeyId cannot be empty", err.Error())
assert.Equal(t, "the access key id is empty", err.Error())
assert.Nil(t, cred)

config.SetAccessKeyId("akid")
config.SetAccessKeySecret("")
cred, err = NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "AccessKeySecret cannot be empty", err.Error())
assert.Equal(t, "the access key secret is empty", err.Error())
assert.Nil(t, cred)

config.SetAccessKeySecret("AccessKeySecret")
cred, err = NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "RoleArn cannot be empty", err.Error())
assert.Equal(t, "the RoleArn is empty", err.Error())
assert.Nil(t, cred)

config.SetRoleArn("roleArn")
cred, err = NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "RoleSessionName cannot be empty", err.Error())
assert.Nil(t, cred)

config.SetRoleSessionName("RoleSessionName")
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)
}
Expand Down
65 changes: 55 additions & 10 deletions credentials/internal/providers/ram_role_arn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,24 @@ type sessionCredentials struct {
Expiration string
}

type HttpOptions struct {
Proxy string
ConnectTimeout int
ReadTimeout int
}

type RAMRoleARNCredentialsProvider struct {
credentialsProvider CredentialsProvider
roleArn string
roleSessionName string
durationSeconds int
policy string
stsRegion string
externalId string
// for sts endpoint
stsRegionId string
stsEndpoint string
// for http options
httpOptions *HttpOptions
// inner
expirationTimestamp int64
lastUpdateTimestamp int64
Expand All @@ -71,8 +81,13 @@ func (builder *RAMRoleARNCredentialsProviderBuilder) WithRoleArn(roleArn string)
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithStsRegion(regionId string) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.stsRegion = regionId
func (builder *RAMRoleARNCredentialsProviderBuilder) WithStsRegionId(regionId string) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.stsRegionId = regionId
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithStsEndpoint(endpoint string) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.stsEndpoint = endpoint
return builder
}

Expand All @@ -96,6 +111,11 @@ func (builder *RAMRoleARNCredentialsProviderBuilder) WithDurationSeconds(duratio
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithHttpOptions(httpOptions *HttpOptions) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.httpOptions = httpOptions
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) Build() (provider *RAMRoleARNCredentialsProvider, err error) {
if builder.provider.credentialsProvider == nil {
err = errors.New("must specify a previous credentials provider to asssume role")
Expand All @@ -122,19 +142,22 @@ func (builder *RAMRoleARNCredentialsProviderBuilder) Build() (provider *RAMRoleA
return
}

// sts endpoint
if builder.provider.stsEndpoint == "" {
if builder.provider.stsRegionId != "" {
builder.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", builder.provider.stsRegionId)
} else {
builder.provider.stsEndpoint = "sts.aliyuncs.com"
}
}

provider = builder.provider
return
}

func (provider *RAMRoleARNCredentialsProvider) getCredentials(cc *Credentials) (session *sessionCredentials, err error) {
method := "POST"
var host string
if provider.stsRegion != "" {
host = fmt.Sprintf("sts.%s.aliyuncs.com", provider.stsRegion)
} else {
host = "sts.aliyuncs.com"
}

host := provider.stsEndpoint
queries := make(map[string]string)
queries["Version"] = "2015-04-01"
queries["Action"] = "AssumeRole"
Expand Down Expand Up @@ -194,6 +217,23 @@ func (provider *RAMRoleARNCredentialsProvider) getCredentials(cc *Credentials) (
httpRequest.Header["x-credentials-provider"] = []string{cc.ProviderName}
httpClient := &http.Client{}

if provider.httpOptions != nil {
httpClient.Timeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Second
proxy := &url.URL{}
if provider.httpOptions.Proxy != "" {
proxy, err = url.Parse(provider.httpOptions.Proxy)
if err != nil {
return
}
}
trans := &http.Transport{}
if proxy != nil && provider.httpOptions.Proxy != "" {
trans.Proxy = http.ProxyURL(proxy)
}
trans.DialContext = utils.Timeout(time.Duration(provider.httpOptions.ConnectTimeout) * time.Second)
httpClient.Transport = trans
}

httpResponse, err := hookDo(httpClient.Do)(httpRequest)
if err != nil {
return
Expand Down Expand Up @@ -269,6 +309,11 @@ func (provider *RAMRoleARNCredentialsProvider) GetCredentials() (cc *Credentials
AccessKeyId: provider.sessionCredentials.AccessKeyId,
AccessKeySecret: provider.sessionCredentials.AccessKeySecret,
SecurityToken: provider.sessionCredentials.SecurityToken,
ProviderName: fmt.Sprintf("%s/%s", provider.GetProviderName(), provider.credentialsProvider.GetProviderName()),
}
return
}

func (provider *RAMRoleARNCredentialsProvider) GetProviderName() string {
return "ram_role_arn"
}
54 changes: 51 additions & 3 deletions credentials/internal/providers/ram_role_arn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestNewRAMRoleARNCredentialsProvider(t *testing.T) {
p, err = NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(akProvider).
WithRoleArn("roleArn").
WithStsRegion("cn-hangzhou").
WithStsRegionId("cn-hangzhou").
WithPolicy("policy").
WithExternalId("externalId").
WithRoleSessionName("rsn").
Expand All @@ -76,8 +76,47 @@ func TestNewRAMRoleARNCredentialsProvider(t *testing.T) {
assert.Equal(t, "roleArn", p.roleArn)
assert.Equal(t, "policy", p.policy)
assert.Equal(t, "externalId", p.externalId)
assert.Equal(t, "cn-hangzhou", p.stsRegion)
assert.Equal(t, "cn-hangzhou", p.stsRegionId)
assert.Equal(t, 1000, p.durationSeconds)
// sts endpoint with sts region
assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", p.stsEndpoint)

// sts endpoint with sts endpoint
p, err = NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(akProvider).
WithRoleArn("roleArn").
WithStsEndpoint("sts.cn-shanghai.aliyuncs.com").
WithPolicy("policy").
WithExternalId("externalId").
WithRoleSessionName("rsn").
WithDurationSeconds(1000).
Build()
assert.Nil(t, err)
assert.Equal(t, "rsn", p.roleSessionName)
assert.Equal(t, "roleArn", p.roleArn)
assert.Equal(t, "policy", p.policy)
assert.Equal(t, "externalId", p.externalId)
assert.Equal(t, "", p.stsRegionId)
assert.Equal(t, 1000, p.durationSeconds)
assert.Equal(t, "sts.cn-shanghai.aliyuncs.com", p.stsEndpoint)

// default sts endpoint
p, err = NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(akProvider).
WithRoleArn("roleArn").
WithPolicy("policy").
WithExternalId("externalId").
WithRoleSessionName("rsn").
WithDurationSeconds(1000).
Build()
assert.Nil(t, err)
assert.Equal(t, "rsn", p.roleSessionName)
assert.Equal(t, "roleArn", p.roleArn)
assert.Equal(t, "policy", p.policy)
assert.Equal(t, "externalId", p.externalId)
assert.Equal(t, "", p.stsRegionId)
assert.Equal(t, 1000, p.durationSeconds)
assert.Equal(t, "sts.aliyuncs.com", p.stsEndpoint)
}

func TestRAMRoleARNCredentialsProvider_getCredentials(t *testing.T) {
Expand Down Expand Up @@ -228,7 +267,7 @@ func TestRAMRoleARNCredentialsProvider_getCredentialsWithRequestCheck(t *testing
WithRoleSessionName("rsn").
WithDurationSeconds(1000).
WithPolicy("policy").
WithStsRegion("cn-beijing").
WithStsRegionId("cn-beijing").
WithExternalId("externalId").
Build()
assert.Nil(t, err)
Expand Down Expand Up @@ -333,6 +372,15 @@ func TestRAMRoleARNCredentialsProviderGetCredentials(t *testing.T) {
assert.Equal(t, "akid", cc.AccessKeyId)
assert.Equal(t, "aksecret", cc.AccessKeySecret)
assert.Equal(t, "ststoken", cc.SecurityToken)
assert.Equal(t, "ram_role_arn/static_ak", cc.ProviderName)
assert.True(t, p.needUpdateCredential())
// get credentials again
cc, err = p.GetCredentials()
assert.Nil(t, err)
assert.Equal(t, "akid", cc.AccessKeyId)
assert.Equal(t, "aksecret", cc.AccessKeySecret)
assert.Equal(t, "ststoken", cc.SecurityToken)
assert.Equal(t, "ram_role_arn/static_ak", cc.ProviderName)
assert.True(t, p.needUpdateCredential())
}

Expand Down

0 comments on commit a4f7a0f

Please sign in to comment.