Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refine the OIDC credentials provider #100

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 14 additions & 28 deletions credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,30 +207,28 @@ func NewCredential(config *Config) (credential Credential, err error) {
case "credentials_uri":
credential = newURLCredential(tea.StringValue(config.Url))
case "oidc_role_arn":
err = checkoutAssumeRamoidc(config)
if err != nil {
return
}
runtime := &utils.Runtime{
Host: tea.StringValue(config.Host),
Proxy: tea.StringValue(config.Proxy),
ReadTimeout: tea.IntValue(config.Timeout),
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
STSEndpoint: tea.StringValue(config.STSEndpoint),
}
credential, err = newOIDCRoleArnCredential(
tea.StringValue(config.AccessKeyId),
tea.StringValue(config.AccessKeySecret),
tea.StringValue(config.RoleArn),
tea.StringValue(config.OIDCProviderArn),
tea.StringValue(config.OIDCTokenFilePath),
tea.StringValue(config.RoleSessionName),
tea.StringValue(config.Policy),
tea.IntValue(config.RoleSessionExpiration),
runtime)

provider, err := providers.NewOIDCCredentialsProviderBuilder().
WithRoleArn(tea.StringValue(config.RoleArn)).
WithOIDCTokenFilePath(tea.StringValue(config.OIDCTokenFilePath)).
WithOIDCProviderARN(tea.StringValue(config.OIDCProviderArn)).
WithDurationSeconds(tea.IntValue(config.RoleSessionExpiration)).
WithPolicy(tea.StringValue(config.Policy)).
WithRoleSessionName(tea.StringValue(config.RoleSessionName)).
WithSTSEndpoint(tea.StringValue(config.STSEndpoint)).
WithRuntime(runtime).
Build()

if err != nil {
return
return nil, err
}
credential = fromCredentialsProvider("oidc_role_arn", provider)
case "access_key":
provider, err := providers.NewStaticAKCredentialsProviderBuilder().
WithAccessKeyId(tea.StringValue(config.AccessKeyId)).
Expand Down Expand Up @@ -358,18 +356,6 @@ func checkRSAKeyPair(config *Config) (err error) {
return
}

func checkoutAssumeRamoidc(config *Config) (err error) {
if tea.StringValue(config.RoleArn) == "" {
err = errors.New("RoleArn cannot be empty")
return
}
if tea.StringValue(config.OIDCProviderArn) == "" {
err = errors.New("OIDCProviderArn cannot be empty")
return
}
return
}

func doAction(request *request.CommonRequest, runtime *utils.Runtime) (content []byte, err error) {
var urlEncoded string
if request.BodyParams != nil {
Expand Down
15 changes: 6 additions & 9 deletions credentials/credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,25 +213,22 @@ func TestNewCredentialWithOIDC(t *testing.T) {
config.SetType("oidc_role_arn")
cred, err := NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "RoleArn cannot be empty", err.Error())
assert.Equal(t, "the OIDCTokenFilePath is empty", err.Error())
assert.Nil(t, cred)

config.SetRoleArn("role_arn")
config.SetOIDCTokenFilePath("oidc_token_file_path_test")
cred, err = NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "OIDCProviderArn cannot be empty", err.Error())
assert.Equal(t, "the OIDCProviderARN is empty", err.Error())
assert.Nil(t, cred)

config.SetOIDCProviderArn("oidc_provider_arn_test").
SetRoleArn("role_arn_test")
config.SetOIDCProviderArn("oidc_provider_arn_test")
cred, err = NewCredential(config)
assert.NotNil(t, err)
assert.Equal(t, "the OIDC token file path is empty", err.Error())
assert.Equal(t, "the RoleArn is empty", err.Error())
assert.Nil(t, cred)

config.SetOIDCProviderArn("oidc_provider_arn_test").
SetOIDCTokenFilePath("oidc_token_file_path_test").
SetRoleArn("role_arn_test")
config.SetRoleArn("role_arn_test")
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)
Expand Down
1 change: 1 addition & 0 deletions credentials/internal/providers/fixtures/mock_oidctoken
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mock oidc token
267 changes: 267 additions & 0 deletions credentials/internal/providers/oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
package providers

import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"strconv"
"strings"
"time"

"github.com/aliyun/credentials-go/credentials/internal/utils"
)

type OIDCCredentialsProvider struct {
oidcProviderARN string
oidcTokenFilePath string
roleArn string
roleSessionName string
durationSeconds int
policy string
stsRegionId string
stsEndpoint string
lastUpdateTimestamp int64
expirationTimestamp int64
sessionCredentials *sessionCredentials
runtime *utils.Runtime
}

type OIDCCredentialsProviderBuilder struct {
provider *OIDCCredentialsProvider
}

func NewOIDCCredentialsProviderBuilder() *OIDCCredentialsProviderBuilder {
return &OIDCCredentialsProviderBuilder{
provider: &OIDCCredentialsProvider{},
}
}

func (b *OIDCCredentialsProviderBuilder) WithOIDCProviderARN(oidcProviderArn string) *OIDCCredentialsProviderBuilder {
b.provider.oidcProviderARN = oidcProviderArn
return b
}

func (b *OIDCCredentialsProviderBuilder) WithOIDCTokenFilePath(oidcTokenFilePath string) *OIDCCredentialsProviderBuilder {
b.provider.oidcTokenFilePath = oidcTokenFilePath
return b
}

func (b *OIDCCredentialsProviderBuilder) WithRoleArn(roleArn string) *OIDCCredentialsProviderBuilder {
b.provider.roleArn = roleArn
return b
}

func (b *OIDCCredentialsProviderBuilder) WithRoleSessionName(roleSessionName string) *OIDCCredentialsProviderBuilder {
b.provider.roleSessionName = roleSessionName
return b
}

func (b *OIDCCredentialsProviderBuilder) WithDurationSeconds(durationSeconds int) *OIDCCredentialsProviderBuilder {
b.provider.durationSeconds = durationSeconds
return b
}

func (b *OIDCCredentialsProviderBuilder) WithStsRegionId(regionId string) *OIDCCredentialsProviderBuilder {
b.provider.stsRegionId = regionId
return b
}

func (b *OIDCCredentialsProviderBuilder) WithPolicy(policy string) *OIDCCredentialsProviderBuilder {
b.provider.policy = policy
return b
}

func (b *OIDCCredentialsProviderBuilder) WithSTSEndpoint(stsEndpoint string) *OIDCCredentialsProviderBuilder {
b.provider.stsEndpoint = stsEndpoint
return b
}

func (b *OIDCCredentialsProviderBuilder) WithRuntime(runtime *utils.Runtime) *OIDCCredentialsProviderBuilder {
b.provider.runtime = runtime
return b
}

func (b *OIDCCredentialsProviderBuilder) Build() (provider *OIDCCredentialsProvider, err error) {
if b.provider.roleSessionName == "" {
b.provider.roleSessionName = "credentials-go-" + strconv.FormatInt(time.Now().UnixNano()/1000, 10)
}

if b.provider.oidcTokenFilePath == "" {
b.provider.oidcTokenFilePath = os.Getenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE")
}

if b.provider.oidcTokenFilePath == "" {
err = errors.New("the OIDCTokenFilePath is empty")
return
}

if b.provider.oidcProviderARN == "" {
b.provider.oidcProviderARN = os.Getenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN")
}

if b.provider.oidcProviderARN == "" {
err = errors.New("the OIDCProviderARN is empty")
return
}

if b.provider.roleArn == "" {
b.provider.roleArn = os.Getenv("ALIBABA_CLOUD_ROLE_ARN")
}

if b.provider.roleArn == "" {
err = errors.New("the RoleArn is empty")
return
}

if b.provider.durationSeconds == 0 {
b.provider.durationSeconds = 3600
}

if b.provider.durationSeconds < 900 {
err = errors.New("the Assume Role session duration should be in the range of 15min - max duration seconds")
}

if b.provider.stsEndpoint == "" {
if b.provider.stsRegionId != "" {
b.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", b.provider.stsRegionId)
} else {
b.provider.stsEndpoint = "sts.aliyuncs.com"
}
}

provider = b.provider
return
}

func (provider *OIDCCredentialsProvider) getCredentials() (session *sessionCredentials, err error) {
method := "POST"
host := provider.stsEndpoint
queries := make(map[string]string)
queries["Version"] = "2015-04-01"
queries["Action"] = "AssumeRoleWithOIDC"
queries["Format"] = "JSON"
queries["Timestamp"] = utils.GetTimeInFormatISO8601()

bodyForm := make(map[string]string)
bodyForm["RoleArn"] = provider.roleArn
bodyForm["OIDCProviderArn"] = provider.oidcProviderARN
token, err := ioutil.ReadFile(provider.oidcTokenFilePath)
if err != nil {
return
}

bodyForm["OIDCToken"] = string(token)
if provider.policy != "" {
bodyForm["Policy"] = provider.policy
}

bodyForm["RoleSessionName"] = provider.roleSessionName
bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds)

// caculate signature
signParams := make(map[string]string)
for key, value := range queries {
signParams[key] = value
}
for key, value := range bodyForm {
signParams[key] = value
}

querystring := utils.GetURLFormedMap(queries)
// do request
httpUrl := fmt.Sprintf("https://%s/?%s", host, querystring)

body := utils.GetURLFormedMap(bodyForm)

httpRequest, err := hookNewRequest(http.NewRequest)(method, httpUrl, strings.NewReader(body))
if err != nil {
return
}

// set headers
httpRequest.Header["Accept-Encoding"] = []string{"identity"}
httpRequest.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"}
httpClient := &http.Client{}

httpResponse, err := hookDo(httpClient.Do)(httpRequest)
if err != nil {
return
}

defer httpResponse.Body.Close()

responseBody, err := ioutil.ReadAll(httpResponse.Body)
if err != nil {
return
}

if httpResponse.StatusCode != http.StatusOK {
message := "get session token failed: "
err = errors.New(message + string(responseBody))
return
}
var data assumeRoleResponse
err = json.Unmarshal(responseBody, &data)
if err != nil {
err = fmt.Errorf("get oidc sts token err, json.Unmarshal fail: %s", err.Error())
return
}
if data.Credentials == nil {
err = fmt.Errorf("get oidc sts token err, fail to get credentials")
return
}

if data.Credentials.AccessKeyId == nil || data.Credentials.AccessKeySecret == nil || data.Credentials.SecurityToken == nil {
err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials")
return
}

session = &sessionCredentials{
AccessKeyId: *data.Credentials.AccessKeyId,
AccessKeySecret: *data.Credentials.AccessKeySecret,
SecurityToken: *data.Credentials.SecurityToken,
Expiration: *data.Credentials.Expiration,
}
return
}

func (provider *OIDCCredentialsProvider) needUpdateCredential() (result bool) {
if provider.expirationTimestamp == 0 {
return true
}

return provider.expirationTimestamp-time.Now().Unix() <= 180
}

func (provider *OIDCCredentialsProvider) GetCredentials() (cc *Credentials, err error) {
if provider.sessionCredentials == nil || provider.needUpdateCredential() {
sessionCredentials, err1 := provider.getCredentials()
if err1 != nil {
return nil, err1
}

provider.sessionCredentials = sessionCredentials
expirationTime, err2 := time.Parse("2006-01-02T15:04:05Z", sessionCredentials.Expiration)
if err2 != nil {
return nil, err2
}

provider.lastUpdateTimestamp = time.Now().Unix()
provider.expirationTimestamp = expirationTime.Unix()
}

cc = &Credentials{
AccessKeyId: provider.sessionCredentials.AccessKeyId,
AccessKeySecret: provider.sessionCredentials.AccessKeySecret,
SecurityToken: provider.sessionCredentials.SecurityToken,
ProviderName: provider.GetProviderName(),
}
return
}

func (provider *OIDCCredentialsProvider) GetProviderName() string {
return "oidc"
}
Loading
Loading