Skip to content

Commit

Permalink
refine RAM role arn credentials provider
Browse files Browse the repository at this point in the history
  • Loading branch information
JacksonTian committed Aug 19, 2024
1 parent cd6b32d commit 8669e2e
Show file tree
Hide file tree
Showing 5 changed files with 772 additions and 2 deletions.
22 changes: 22 additions & 0 deletions credentials/internal/providers/hook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package providers

import (
"io"
"net/http"
)

type newReuqest func(method, url string, body io.Reader) (*http.Request, error)

var hookNewRequest = func(fn newReuqest) newReuqest {
return fn
}

type do func(req *http.Request) (*http.Response, error)

var hookDo = func(fn do) do {
return fn
}

var hookParse = func(err error) error {
return err
}
90 changes: 90 additions & 0 deletions credentials/internal/providers/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package providers

import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/alibabacloud-go/debug/debug"
"github.com/aliyun/credentials-go/credentials/internal/utils"
"github.com/aliyun/credentials-go/credentials/request"
"github.com/aliyun/credentials-go/credentials/response"
)

var debuglog = debug.Init("credential")

func mockResponse(statusCode int, content string) (res *http.Response) {
status := strconv.Itoa(statusCode)
res = &http.Response{
Proto: "HTTP/1.1",
ProtoMajor: 1,
Header: map[string][]string{"sdk": {"test"}},
StatusCode: statusCode,
Status: status + " " + http.StatusText(statusCode),
}
res.Body = ioutil.NopCloser(bytes.NewReader([]byte(content)))
return
}

func doAction(request *request.CommonRequest, runtime *utils.Runtime) (content []byte, err error) {
var urlEncoded string
if request.BodyParams != nil {
urlEncoded = utils.GetURLFormedMap(request.BodyParams)
}
httpRequest, err := http.NewRequest(request.Method, request.URL, strings.NewReader(urlEncoded))
if err != nil {
return
}
httpRequest.Proto = "HTTP/1.1"
httpRequest.Host = request.Domain
debuglog("> %s %s %s", httpRequest.Method, httpRequest.URL.RequestURI(), httpRequest.Proto)
debuglog("> Host: %s", httpRequest.Host)
for key, value := range request.Headers {
if value != "" {
debuglog("> %s: %s", key, value)
httpRequest.Header[key] = []string{value}
}
}
debuglog(">")
httpClient := &http.Client{}
httpClient.Timeout = time.Duration(runtime.ReadTimeout) * time.Second
proxy := &url.URL{}
if runtime.Proxy != "" {
proxy, err = url.Parse(runtime.Proxy)
if err != nil {
return
}
}
trans := &http.Transport{}
if proxy != nil && runtime.Proxy != "" {
trans.Proxy = http.ProxyURL(proxy)
}
trans.DialContext = utils.Timeout(time.Duration(runtime.ConnectTimeout) * time.Second)
httpClient.Transport = trans
httpResponse, err := hookDo(httpClient.Do)(httpRequest)
if err != nil {
return
}
debuglog("< %s %s", httpResponse.Proto, httpResponse.Status)
for key, value := range httpResponse.Header {
debuglog("< %s: %v", key, strings.Join(value, ""))
}
debuglog("<")

resp := &response.CommonResponse{}
err = hookParse(resp.ParseFromHTTPResponse(httpResponse))
if err != nil {
return
}
debuglog("%s", resp.GetHTTPContentString())
if resp.GetHTTPStatus() != http.StatusOK {
err = fmt.Errorf("httpStatus: %d, message = %s", resp.GetHTTPStatus(), resp.GetHTTPContentString())
return
}
return resp.GetHTTPContentBytes(), nil
}
274 changes: 274 additions & 0 deletions credentials/internal/providers/ram_role_arn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
package providers

import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/aliyun/credentials-go/credentials/internal/utils"
)

type assumedRoleUser struct {
}

type credentials struct {
SecurityToken *string `json:"SecurityToken"`
Expiration *string `json:"Expiration"`
AccessKeySecret *string `json:"AccessKeySecret"`
AccessKeyId *string `json:"AccessKeyId"`
}

type assumeRoleResponse struct {
RequestID *string `json:"RequestId"`
AssumedRoleUser *assumedRoleUser `json:"AssumedRoleUser"`
Credentials *credentials `json:"Credentials"`
}

type sessionCredentials struct {
AccessKeyId string
AccessKeySecret string
SecurityToken string
Expiration string
}

type RAMRoleARNCredentialsProvider struct {
credentialsProvider CredentialsProvider
roleArn string
roleSessionName string
durationSeconds int
policy string
stsRegion string
externalId string
// inner
expirationTimestamp int64
lastUpdateTimestamp int64
sessionCredentials *sessionCredentials
}

type RAMRoleARNCredentialsProviderBuilder struct {
provider *RAMRoleARNCredentialsProvider
}

func NewRAMRoleARNCredentialsProviderBuilder() *RAMRoleARNCredentialsProviderBuilder {
return &RAMRoleARNCredentialsProviderBuilder{
provider: &RAMRoleARNCredentialsProvider{},
}
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithCredentialsProvider(credentialsProvider CredentialsProvider) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.credentialsProvider = credentialsProvider
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithRoleArn(roleArn string) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.roleArn = roleArn
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithStsRegion(regionId string) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.stsRegion = regionId
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithRoleSessionName(roleSessionName string) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.roleSessionName = roleSessionName
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithPolicy(policy string) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.policy = policy
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithExternalId(externalId string) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.externalId = externalId
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) WithDurationSeconds(durationSeconds int) *RAMRoleARNCredentialsProviderBuilder {
builder.provider.durationSeconds = durationSeconds
return builder
}

func (builder *RAMRoleARNCredentialsProviderBuilder) Build() (provider *RAMRoleARNCredentialsProvider, err error) {
if builder.provider.credentialsProvider == nil {
err = errors.New("must specify a previous credentials provider to asssume role")
return
}

if builder.provider.roleArn == "" {
err = errors.New("the RoleArn is empty")
return
}

if builder.provider.roleSessionName == "" {
builder.provider.roleSessionName = "credentials-go-" + strconv.FormatInt(time.Now().UnixNano()/1000, 10)
}

// duration seconds
if builder.provider.durationSeconds == 0 {
// default to 3600
builder.provider.durationSeconds = 3600
}

if builder.provider.durationSeconds < 900 {
err = errors.New("session duration should be in the range of 900s - max session duration")
return
}

provider = builder.provider
return
}

func (provider *RAMRoleARNCredentialsProvider) getCredentials(cc *Credentials) (session *sessionCredentials, err error) {
method := "POST"
var host string
if provider.stsRegion != "" {
host = fmt.Sprintf("sts.%s.aliyuncs.com", provider.stsRegion)
} else {
host = "sts.aliyuncs.com"
}

queries := make(map[string]string)
queries["Version"] = "2015-04-01"
queries["Action"] = "AssumeRole"
queries["Format"] = "JSON"
queries["Timestamp"] = utils.GetTimeInFormatISO8601()
queries["SignatureMethod"] = "HMAC-SHA1"
queries["SignatureVersion"] = "1.0"
queries["SignatureNonce"] = utils.GetNonce()
queries["AccessKeyId"] = cc.AccessKeyId
if cc.SecurityToken != "" {
queries["SecurityToken"] = cc.SecurityToken
}

bodyForm := make(map[string]string)
bodyForm["RoleArn"] = provider.roleArn
if provider.policy != "" {
bodyForm["Policy"] = provider.policy
}
if provider.externalId != "" {
bodyForm["ExternalId"] = provider.externalId
}
bodyForm["RoleSessionName"] = provider.roleSessionName
bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds)

// caculate signature
signParams := make(map[string]string)
for key, value := range queries {
signParams[key] = value
}
for key, value := range bodyForm {
signParams[key] = value
}

stringToSign := utils.GetURLFormedMap(signParams)
stringToSign = strings.Replace(stringToSign, "+", "%20", -1)
stringToSign = strings.Replace(stringToSign, "*", "%2A", -1)
stringToSign = strings.Replace(stringToSign, "%7E", "~", -1)
stringToSign = url.QueryEscape(stringToSign)
stringToSign = method + "&%2F&" + stringToSign
secret := cc.AccessKeySecret + "&"
queries["Signature"] = utils.ShaHmac1(stringToSign, secret)

querystring := utils.GetURLFormedMap(queries)
// do request
httpUrl := fmt.Sprintf("https://%s/?%s", host, querystring)

body := utils.GetURLFormedMap(bodyForm)

httpRequest, err := hookNewRequest(http.NewRequest)(method, httpUrl, strings.NewReader(body))
if err != nil {
return
}

// set headers
httpRequest.Header["Accept-Encoding"] = []string{"identity"}
httpRequest.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"}
httpRequest.Header["x-credentials-provider"] = []string{cc.ProviderName}
httpClient := &http.Client{}

httpResponse, err := hookDo(httpClient.Do)(httpRequest)
if err != nil {
return
}

defer httpResponse.Body.Close()

responseBody, err := ioutil.ReadAll(httpResponse.Body)
if err != nil {
return
}

if httpResponse.StatusCode != http.StatusOK {
err = errors.New("refresh session token failed: " + string(responseBody))
return
}
var data assumeRoleResponse
err = json.Unmarshal(responseBody, &data)
if err != nil {
err = fmt.Errorf("refresh RoleArn sts token err, json.Unmarshal fail: %s", err.Error())
return
}
if data.Credentials == nil {
err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials")
return
}

if data.Credentials.AccessKeyId == nil || data.Credentials.AccessKeySecret == nil || data.Credentials.SecurityToken == nil {
err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials")
return
}

session = &sessionCredentials{
AccessKeyId: *data.Credentials.AccessKeyId,
AccessKeySecret: *data.Credentials.AccessKeySecret,
SecurityToken: *data.Credentials.SecurityToken,
Expiration: *data.Credentials.Expiration,
}
return
}

func (provider *RAMRoleARNCredentialsProvider) needUpdateCredential() (result bool) {
if provider.expirationTimestamp == 0 {
return true
}

return provider.expirationTimestamp-time.Now().Unix() <= 180
}

func (provider *RAMRoleARNCredentialsProvider) GetCredentials() (cc *Credentials, err error) {
if provider.sessionCredentials == nil || provider.needUpdateCredential() {
// 获取前置凭证
previousCredentials, err1 := provider.credentialsProvider.GetCredentials()
if err1 != nil {
return nil, err1
}
sessionCredentials, err2 := provider.getCredentials(previousCredentials)
if err2 != nil {
return nil, err2
}

expirationTime, err := time.Parse("2006-01-02T15:04:05Z", sessionCredentials.Expiration)
if err != nil {
return nil, err
}

provider.expirationTimestamp = expirationTime.Unix()
provider.lastUpdateTimestamp = time.Now().Unix()
provider.sessionCredentials = sessionCredentials
}

cc = &Credentials{
AccessKeyId: provider.sessionCredentials.AccessKeyId,
AccessKeySecret: provider.sessionCredentials.AccessKeySecret,
SecurityToken: provider.sessionCredentials.SecurityToken,
}
return
}
Loading

0 comments on commit 8669e2e

Please sign in to comment.