Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve ram role arn credentials provider #96

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 34 additions & 42 deletions credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,27 +265,43 @@ func NewCredential(config *Config) (credential Credential, err error) {
tea.Float64Value(config.InAdvanceScale),
runtime)
case "ram_role_arn":
err = checkRAMRoleArn(config)
var credentialsProvider providers.CredentialsProvider
if config.SecurityToken != nil {
credentialsProvider, err = providers.NewStaticSTSCredentialsProviderBuilder().
WithAccessKeyId(tea.StringValue(config.AccessKeyId)).
WithAccessKeySecret(tea.StringValue(config.AccessKeySecret)).
WithSecurityToken(tea.StringValue(config.SecurityToken)).
Build()
} else {
credentialsProvider, err = providers.NewStaticAKCredentialsProviderBuilder().
WithAccessKeyId(tea.StringValue(config.AccessKeyId)).
WithAccessKeySecret(tea.StringValue(config.AccessKeySecret)).
Build()
}

if err != nil {
return
return nil, err
}
runtime := &utils.Runtime{
Host: tea.StringValue(config.Host),
Proxy: tea.StringValue(config.Proxy),
ReadTimeout: tea.IntValue(config.Timeout),
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
STSEndpoint: tea.StringValue(config.STSEndpoint),

provider, err := providers.NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(credentialsProvider).
WithRoleArn(tea.StringValue(config.RoleArn)).
WithRoleSessionName(tea.StringValue(config.RoleSessionName)).
WithPolicy(tea.StringValue(config.Policy)).
WithDurationSeconds(tea.IntValue(config.RoleSessionExpiration)).
WithExternalId(tea.StringValue(config.ExternalId)).
WithStsEndpoint(tea.StringValue(config.STSEndpoint)).
WithHttpOptions(&providers.HttpOptions{
Proxy: tea.StringValue(config.Proxy),
ReadTimeout: tea.IntValue(config.Timeout),
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
}).
Build()
if err != nil {
return nil, err
}
credential = newRAMRoleArnl(
tea.StringValue(config.AccessKeyId),
tea.StringValue(config.AccessKeySecret),
tea.StringValue(config.SecurityToken),
tea.StringValue(config.RoleArn),
tea.StringValue(config.RoleSessionName),
tea.StringValue(config.Policy),
tea.IntValue(config.RoleSessionExpiration),
tea.StringValue(config.ExternalId),
runtime)

credential = fromCredentialsProvider("ram_role_arn", provider)
case "rsa_key_pair":
err = checkRSAKeyPair(config)
if err != nil {
Expand Down Expand Up @@ -354,30 +370,6 @@ func checkoutAssumeRamoidc(config *Config) (err error) {
return
}

func checkRAMRoleArn(config *Config) (err error) {
if tea.StringValue(config.AccessKeyId) == "" {
err = errors.New("AccessKeyId cannot be empty")
return
}

if tea.StringValue(config.AccessKeySecret) == "" {
err = errors.New("AccessKeySecret cannot be empty")
return
}

if tea.StringValue(config.RoleArn) == "" {
err = errors.New("RoleArn cannot be empty")
return
}

if tea.StringValue(config.RoleSessionName) == "" {
err = errors.New("RoleSessionName cannot be empty")
return
}

return
}

func doAction(request *request.CommonRequest, runtime *utils.Runtime) (content []byte, err error) {
var urlEncoded string
if request.BodyParams != nil {
Expand Down
12 changes: 3 additions & 9 deletions credentials/credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,30 +171,24 @@ func TestNewCredentialWithRAMRoleARN(t *testing.T) {
config.SetAccessKeyId("")
cred, err := NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "AccessKeyId cannot be empty", err.Error())
assert.Equal(t, "the access key id is empty", err.Error())
assert.Nil(t, cred)

config.SetAccessKeyId("akid")
config.SetAccessKeySecret("")
cred, err = NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "AccessKeySecret cannot be empty", err.Error())
assert.Equal(t, "the access key secret is empty", err.Error())
assert.Nil(t, cred)

config.SetAccessKeySecret("AccessKeySecret")
cred, err = NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "RoleArn cannot be empty", err.Error())
assert.Equal(t, "the RoleArn is empty", err.Error())
assert.Nil(t, cred)

config.SetRoleArn("roleArn")
cred, err = NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "RoleSessionName cannot be empty", err.Error())
assert.Nil(t, cred)

config.SetRoleSessionName("RoleSessionName")
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)
}
Expand Down
65 changes: 55 additions & 10 deletions credentials/internal/providers/ram_role_arn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,24 @@ type sessionCredentials struct {
Expiration string
}

type HttpOptions struct {
Proxy string
ConnectTimeout int
ReadTimeout int
}

type RAMRoleARNCredentialsProvider struct {
credentialsProvider CredentialsProvider
roleArn string
roleSessionName string
durationSeconds int
policy string
stsRegion string
externalId string
// for sts endpoint
stsRegionId string
stsEndpoint string
// for http options
httpOptions *HttpOptions
// inner
expirationTimestamp int64
lastUpdateTimestamp int64
Expand All @@ -71,8 +81,13 @@ func (builder *RAMRoleARNCredentialsProviderBuilder) WithRoleArn(roleArn string)
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithStsRegion(regionId string) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.stsRegion = regionId
func (builder *RAMRoleARNCredentialsProviderBuilder) WithStsRegionId(regionId string) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.stsRegionId = regionId
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithStsEndpoint(endpoint string) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.stsEndpoint = endpoint
return builder
}

Expand All @@ -96,6 +111,11 @@ func (builder *RAMRoleARNCredentialsProviderBuilder) WithDurationSeconds(duratio
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithHttpOptions(httpOptions *HttpOptions) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.httpOptions = httpOptions
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) Build() (provider *RAMRoleARNCredentialsProvider, err error) {
if builder.provider.credentialsProvider == nil {
err = errors.New("must specify a previous credentials provider to asssume role")
Expand All @@ -122,19 +142,22 @@ func (builder *RAMRoleARNCredentialsProviderBuilder) Build() (provider *RAMRoleA
return
}

// sts endpoint
if builder.provider.stsEndpoint == "" {
if builder.provider.stsRegionId != "" {
builder.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", builder.provider.stsRegionId)
} else {
builder.provider.stsEndpoint = "sts.aliyuncs.com"
}
}

provider = builder.provider
return
}

func (provider *RAMRoleARNCredentialsProvider) getCredentials(cc *Credentials) (session *sessionCredentials, err error) {
method := "POST"
var host string
if provider.stsRegion != "" {
host = fmt.Sprintf("sts.%s.aliyuncs.com", provider.stsRegion)
} else {
host = "sts.aliyuncs.com"
}

host := provider.stsEndpoint
queries := make(map[string]string)
queries["Version"] = "2015-04-01"
queries["Action"] = "AssumeRole"
Expand Down Expand Up @@ -194,6 +217,23 @@ func (provider *RAMRoleARNCredentialsProvider) getCredentials(cc *Credentials) (
httpRequest.Header["x-credentials-provider"] = []string{cc.ProviderName}
httpClient := &http.Client{}

if provider.httpOptions != nil {
httpClient.Timeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Second
proxy := &url.URL{}
if provider.httpOptions.Proxy != "" {
proxy, err = url.Parse(provider.httpOptions.Proxy)
if err != nil {
return
}
}
trans := &http.Transport{}
if proxy != nil && provider.httpOptions.Proxy != "" {
trans.Proxy = http.ProxyURL(proxy)
}
trans.DialContext = utils.Timeout(time.Duration(provider.httpOptions.ConnectTimeout) * time.Second)
httpClient.Transport = trans
}

httpResponse, err := hookDo(httpClient.Do)(httpRequest)
if err != nil {
return
Expand Down Expand Up @@ -269,6 +309,11 @@ func (provider *RAMRoleARNCredentialsProvider) GetCredentials() (cc *Credentials
AccessKeyId: provider.sessionCredentials.AccessKeyId,
AccessKeySecret: provider.sessionCredentials.AccessKeySecret,
SecurityToken: provider.sessionCredentials.SecurityToken,
ProviderName: fmt.Sprintf("%s/%s", provider.GetProviderName(), provider.credentialsProvider.GetProviderName()),
}
return
}

func (provider *RAMRoleARNCredentialsProvider) GetProviderName() string {
return "ram_role_arn"
}
54 changes: 51 additions & 3 deletions credentials/internal/providers/ram_role_arn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestNewRAMRoleARNCredentialsProvider(t *testing.T) {
p, err = NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(akProvider).
WithRoleArn("roleArn").
WithStsRegion("cn-hangzhou").
WithStsRegionId("cn-hangzhou").
WithPolicy("policy").
WithExternalId("externalId").
WithRoleSessionName("rsn").
Expand All @@ -76,8 +76,47 @@ func TestNewRAMRoleARNCredentialsProvider(t *testing.T) {
assert.Equal(t, "roleArn", p.roleArn)
assert.Equal(t, "policy", p.policy)
assert.Equal(t, "externalId", p.externalId)
assert.Equal(t, "cn-hangzhou", p.stsRegion)
assert.Equal(t, "cn-hangzhou", p.stsRegionId)
assert.Equal(t, 1000, p.durationSeconds)
// sts endpoint with sts region
assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", p.stsEndpoint)

// sts endpoint with sts endpoint
p, err = NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(akProvider).
WithRoleArn("roleArn").
WithStsEndpoint("sts.cn-shanghai.aliyuncs.com").
WithPolicy("policy").
WithExternalId("externalId").
WithRoleSessionName("rsn").
WithDurationSeconds(1000).
Build()
assert.Nil(t, err)
assert.Equal(t, "rsn", p.roleSessionName)
assert.Equal(t, "roleArn", p.roleArn)
assert.Equal(t, "policy", p.policy)
assert.Equal(t, "externalId", p.externalId)
assert.Equal(t, "", p.stsRegionId)
assert.Equal(t, 1000, p.durationSeconds)
assert.Equal(t, "sts.cn-shanghai.aliyuncs.com", p.stsEndpoint)

// default sts endpoint
p, err = NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(akProvider).
WithRoleArn("roleArn").
WithPolicy("policy").
WithExternalId("externalId").
WithRoleSessionName("rsn").
WithDurationSeconds(1000).
Build()
assert.Nil(t, err)
assert.Equal(t, "rsn", p.roleSessionName)
assert.Equal(t, "roleArn", p.roleArn)
assert.Equal(t, "policy", p.policy)
assert.Equal(t, "externalId", p.externalId)
assert.Equal(t, "", p.stsRegionId)
assert.Equal(t, 1000, p.durationSeconds)
assert.Equal(t, "sts.aliyuncs.com", p.stsEndpoint)
}

func TestRAMRoleARNCredentialsProvider_getCredentials(t *testing.T) {
Expand Down Expand Up @@ -228,7 +267,7 @@ func TestRAMRoleARNCredentialsProvider_getCredentialsWithRequestCheck(t *testing
WithRoleSessionName("rsn").
WithDurationSeconds(1000).
WithPolicy("policy").
WithStsRegion("cn-beijing").
WithStsRegionId("cn-beijing").
WithExternalId("externalId").
Build()
assert.Nil(t, err)
Expand Down Expand Up @@ -333,6 +372,15 @@ func TestRAMRoleARNCredentialsProviderGetCredentials(t *testing.T) {
assert.Equal(t, "akid", cc.AccessKeyId)
assert.Equal(t, "aksecret", cc.AccessKeySecret)
assert.Equal(t, "ststoken", cc.SecurityToken)
assert.Equal(t, "ram_role_arn/static_ak", cc.ProviderName)
assert.True(t, p.needUpdateCredential())
// get credentials again
cc, err = p.GetCredentials()
assert.Nil(t, err)
assert.Equal(t, "akid", cc.AccessKeyId)
assert.Equal(t, "aksecret", cc.AccessKeySecret)
assert.Equal(t, "ststoken", cc.SecurityToken)
assert.Equal(t, "ram_role_arn/static_ak", cc.ProviderName)
assert.True(t, p.needUpdateCredential())
}

Expand Down
Loading