Skip to content

Commit

Permalink
refine the OIDC credentials provider
Browse files Browse the repository at this point in the history
  • Loading branch information
JacksonTian committed Aug 20, 2024
1 parent 412f4ed commit f5d62c4
Show file tree
Hide file tree
Showing 7 changed files with 652 additions and 441 deletions.
42 changes: 14 additions & 28 deletions credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,30 +207,28 @@ func NewCredential(config *Config) (credential Credential, err error) {
case "credentials_uri":
credential = newURLCredential(tea.StringValue(config.Url))
case "oidc_role_arn":
err = checkoutAssumeRamoidc(config)
if err != nil {
return
}
runtime := &utils.Runtime{
Host: tea.StringValue(config.Host),
Proxy: tea.StringValue(config.Proxy),
ReadTimeout: tea.IntValue(config.Timeout),
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
STSEndpoint: tea.StringValue(config.STSEndpoint),
}
credential, err = newOIDCRoleArnCredential(
tea.StringValue(config.AccessKeyId),
tea.StringValue(config.AccessKeySecret),
tea.StringValue(config.RoleArn),
tea.StringValue(config.OIDCProviderArn),
tea.StringValue(config.OIDCTokenFilePath),
tea.StringValue(config.RoleSessionName),
tea.StringValue(config.Policy),
tea.IntValue(config.RoleSessionExpiration),
runtime)

provider, err := providers.NewOIDCCredentialsProviderBuilder().
WithRoleArn(tea.StringValue(config.RoleArn)).
WithOIDCTokenFilePath(tea.StringValue(config.OIDCTokenFilePath)).
WithOIDCProviderARN(tea.StringValue(config.OIDCProviderArn)).
WithDurationSeconds(tea.IntValue(config.RoleSessionExpiration)).
WithPolicy(tea.StringValue(config.Policy)).
WithRoleSessionName(tea.StringValue(config.RoleSessionName)).
WithSTSEndpoint(tea.StringValue(config.STSEndpoint)).
WithRuntime(runtime).
Build()

if err != nil {
return
return nil, err
}
credential = fromCredentialsProvider("oidc_role_arn", provider)
case "access_key":
provider, err := providers.NewStaticAKCredentialsProviderBuilder().
WithAccessKeyId(tea.StringValue(config.AccessKeyId)).
Expand Down Expand Up @@ -358,18 +356,6 @@ func checkRSAKeyPair(config *Config) (err error) {
return
}

func checkoutAssumeRamoidc(config *Config) (err error) {
if tea.StringValue(config.RoleArn) == "" {
err = errors.New("RoleArn cannot be empty")
return
}
if tea.StringValue(config.OIDCProviderArn) == "" {
err = errors.New("OIDCProviderArn cannot be empty")
return
}
return
}

func doAction(request *request.CommonRequest, runtime *utils.Runtime) (content []byte, err error) {
var urlEncoded string
if request.BodyParams != nil {
Expand Down
15 changes: 6 additions & 9 deletions credentials/credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,25 +213,22 @@ func TestNewCredentialWithOIDC(t *testing.T) {
config.SetType("oidc_role_arn")
cred, err := NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "RoleArn cannot be empty", err.Error())
assert.Equal(t, "the OIDCTokenFilePath is empty", err.Error())
assert.Nil(t, cred)

config.SetRoleArn("role_arn")
config.SetOIDCTokenFilePath("oidc_token_file_path_test")
cred, err = NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "OIDCProviderArn cannot be empty", err.Error())
assert.Equal(t, "the OIDCProviderARN is empty", err.Error())
assert.Nil(t, cred)

config.SetOIDCProviderArn("oidc_provider_arn_test").
SetRoleArn("role_arn_test")
config.SetOIDCProviderArn("oidc_provider_arn_test")
cred, err = NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "the OIDC token file path is empty", err.Error())
assert.Equal(t, "the RoleArn is empty", err.Error())
assert.Nil(t, cred)

config.SetOIDCProviderArn("oidc_provider_arn_test").
SetOIDCTokenFilePath("oidc_token_file_path_test").
SetRoleArn("role_arn_test")
config.SetRoleArn("role_arn_test")
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)
Expand Down
1 change: 1 addition & 0 deletions credentials/internal/providers/fixtures/mock_oidctoken
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mock oidc token
267 changes: 267 additions & 0 deletions credentials/internal/providers/oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
package providers

import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"strconv"
"strings"
"time"

"github.com/aliyun/credentials-go/credentials/internal/utils"
)

type OIDCCredentialsProvider struct {
oidcProviderARN string
oidcTokenFilePath string
roleArn string
roleSessionName string
durationSeconds int
policy string
stsRegionId string
stsEndpoint string
lastUpdateTimestamp int64
expirationTimestamp int64
sessionCredentials *sessionCredentials
runtime *utils.Runtime
}

type OIDCCredentialsProviderBuilder struct {
provider *OIDCCredentialsProvider
}

func NewOIDCCredentialsProviderBuilder() *OIDCCredentialsProviderBuilder {
return &OIDCCredentialsProviderBuilder{
provider: &OIDCCredentialsProvider{},
}
}

func (b *OIDCCredentialsProviderBuilder) WithOIDCProviderARN(oidcProviderArn string) *OIDCCredentialsProviderBuilder {
b.provider.oidcProviderARN = oidcProviderArn
return b
}

func (b *OIDCCredentialsProviderBuilder) WithOIDCTokenFilePath(oidcTokenFilePath string) *OIDCCredentialsProviderBuilder {
b.provider.oidcTokenFilePath = oidcTokenFilePath
return b
}

func (b *OIDCCredentialsProviderBuilder) WithRoleArn(roleArn string) *OIDCCredentialsProviderBuilder {
b.provider.roleArn = roleArn
return b
}

func (b *OIDCCredentialsProviderBuilder) WithRoleSessionName(roleSessionName string) *OIDCCredentialsProviderBuilder {
b.provider.roleSessionName = roleSessionName
return b
}

func (b *OIDCCredentialsProviderBuilder) WithDurationSeconds(durationSeconds int) *OIDCCredentialsProviderBuilder {
b.provider.durationSeconds = durationSeconds
return b
}

func (b *OIDCCredentialsProviderBuilder) WithStsRegionId(regionId string) *OIDCCredentialsProviderBuilder {
b.provider.stsRegionId = regionId
return b
}

func (b *OIDCCredentialsProviderBuilder) WithPolicy(policy string) *OIDCCredentialsProviderBuilder {
b.provider.policy = policy
return b
}

func (b *OIDCCredentialsProviderBuilder) WithSTSEndpoint(stsEndpoint string) *OIDCCredentialsProviderBuilder {
b.provider.stsEndpoint = stsEndpoint
return b
}

func (b *OIDCCredentialsProviderBuilder) WithRuntime(runtime *utils.Runtime) *OIDCCredentialsProviderBuilder {
b.provider.runtime = runtime
return b
}

func (b *OIDCCredentialsProviderBuilder) Build() (provider *OIDCCredentialsProvider, err error) {
if b.provider.roleSessionName == "" {
b.provider.roleSessionName = "credentials-go-" + strconv.FormatInt(time.Now().UnixNano()/1000, 10)
}

if b.provider.oidcTokenFilePath == "" {
b.provider.oidcTokenFilePath = os.Getenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE")
}

if b.provider.oidcTokenFilePath == "" {
err = errors.New("the OIDCTokenFilePath is empty")
return
}

if b.provider.oidcProviderARN == "" {
b.provider.oidcProviderARN = os.Getenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN")
}

if b.provider.oidcProviderARN == "" {
err = errors.New("the OIDCProviderARN is empty")
return
}

if b.provider.roleArn == "" {
b.provider.roleArn = os.Getenv("ALIBABA_CLOUD_ROLE_ARN")
}

if b.provider.roleArn == "" {
err = errors.New("the RoleArn is empty")
return
}

if b.provider.durationSeconds == 0 {
b.provider.durationSeconds = 3600
}

if b.provider.durationSeconds < 900 {
err = errors.New("the Assume Role session duration should be in the range of 15min - max duration seconds")
}

if b.provider.stsEndpoint == "" {
if b.provider.stsRegionId != "" {
b.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", b.provider.stsRegionId)
} else {
b.provider.stsEndpoint = "sts.aliyuncs.com"
}
}

provider = b.provider
return
}

func (provider *OIDCCredentialsProvider) getCredentials() (session *sessionCredentials, err error) {
method := "POST"
host := provider.stsEndpoint
queries := make(map[string]string)
queries["Version"] = "2015-04-01"
queries["Action"] = "AssumeRoleWithOIDC"
queries["Format"] = "JSON"
queries["Timestamp"] = utils.GetTimeInFormatISO8601()

bodyForm := make(map[string]string)
bodyForm["RoleArn"] = provider.roleArn
bodyForm["OIDCProviderArn"] = provider.oidcProviderARN
token, err := ioutil.ReadFile(provider.oidcTokenFilePath)
if err != nil {
return
}

bodyForm["OIDCToken"] = string(token)
if provider.policy != "" {
bodyForm["Policy"] = provider.policy
}

bodyForm["RoleSessionName"] = provider.roleSessionName
bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds)

// caculate signature
signParams := make(map[string]string)
for key, value := range queries {
signParams[key] = value
}
for key, value := range bodyForm {
signParams[key] = value
}

querystring := utils.GetURLFormedMap(queries)
// do request
httpUrl := fmt.Sprintf("https://%s/?%s", host, querystring)

body := utils.GetURLFormedMap(bodyForm)

httpRequest, err := hookNewRequest(http.NewRequest)(method, httpUrl, strings.NewReader(body))
if err != nil {
return
}

// set headers
httpRequest.Header["Accept-Encoding"] = []string{"identity"}
httpRequest.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"}
httpClient := &http.Client{}

httpResponse, err := hookDo(httpClient.Do)(httpRequest)
if err != nil {
return
}

defer httpResponse.Body.Close()

responseBody, err := ioutil.ReadAll(httpResponse.Body)
if err != nil {
return
}

if httpResponse.StatusCode != http.StatusOK {
message := "get session token failed: "
err = errors.New(message + string(responseBody))
return
}
var data assumeRoleResponse
err = json.Unmarshal(responseBody, &data)
if err != nil {
err = fmt.Errorf("get oidc sts token err, json.Unmarshal fail: %s", err.Error())
return
}
if data.Credentials == nil {
err = fmt.Errorf("get oidc sts token err, fail to get credentials")
return
}

if data.Credentials.AccessKeyId == nil || data.Credentials.AccessKeySecret == nil || data.Credentials.SecurityToken == nil {
err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials")
return
}

session = &sessionCredentials{
AccessKeyId: *data.Credentials.AccessKeyId,
AccessKeySecret: *data.Credentials.AccessKeySecret,
SecurityToken: *data.Credentials.SecurityToken,
Expiration: *data.Credentials.Expiration,
}
return
}

func (provider *OIDCCredentialsProvider) needUpdateCredential() (result bool) {
if provider.expirationTimestamp == 0 {
return true
}

return provider.expirationTimestamp-time.Now().Unix() <= 180
}

func (provider *OIDCCredentialsProvider) GetCredentials() (cc *Credentials, err error) {
if provider.sessionCredentials == nil || provider.needUpdateCredential() {
sessionCredentials, err1 := provider.getCredentials()
if err1 != nil {
return nil, err1
}

provider.sessionCredentials = sessionCredentials
expirationTime, err2 := time.Parse("2006-01-02T15:04:05Z", sessionCredentials.Expiration)
if err2 != nil {
return nil, err2
}

provider.lastUpdateTimestamp = time.Now().Unix()
provider.expirationTimestamp = expirationTime.Unix()
}

cc = &Credentials{
AccessKeyId: provider.sessionCredentials.AccessKeyId,
AccessKeySecret: provider.sessionCredentials.AccessKeySecret,
SecurityToken: provider.sessionCredentials.SecurityToken,
ProviderName: provider.GetProviderName(),
}
return
}

func (provider *OIDCCredentialsProvider) GetProviderName() string {
return "oidc"
}
Loading

0 comments on commit f5d62c4

Please sign in to comment.