-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refine RAM role arn credentials provider
- Loading branch information
1 parent
cd6b32d
commit 09e2c68
Showing
5 changed files
with
772 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package providers | ||
|
||
import ( | ||
"io" | ||
"net/http" | ||
) | ||
|
||
type newReuqest func(method, url string, body io.Reader) (*http.Request, error) | ||
|
||
var hookNewRequest = func(fn newReuqest) newReuqest { | ||
return fn | ||
} | ||
|
||
type do func(req *http.Request) (*http.Response, error) | ||
|
||
var hookDo = func(fn do) do { | ||
return fn | ||
} | ||
|
||
var hookParse = func(err error) error { | ||
return err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package providers | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"io/ioutil" | ||
"net/http" | ||
"net/url" | ||
"strconv" | ||
"strings" | ||
"time" | ||
|
||
"github.com/alibabacloud-go/debug/debug" | ||
"github.com/aliyun/credentials-go/credentials/internal/utils" | ||
"github.com/aliyun/credentials-go/credentials/request" | ||
"github.com/aliyun/credentials-go/credentials/response" | ||
) | ||
|
||
var debuglog = debug.Init("credential") | ||
|
||
func mockResponse(statusCode int, content string) (res *http.Response) { | ||
status := strconv.Itoa(statusCode) | ||
res = &http.Response{ | ||
Proto: "HTTP/1.1", | ||
ProtoMajor: 1, | ||
Header: map[string][]string{"sdk": {"test"}}, | ||
StatusCode: statusCode, | ||
Status: status + " " + http.StatusText(statusCode), | ||
} | ||
res.Body = ioutil.NopCloser(bytes.NewReader([]byte(content))) | ||
return | ||
} | ||
|
||
func doAction(request *request.CommonRequest, runtime *utils.Runtime) (content []byte, err error) { | ||
var urlEncoded string | ||
if request.BodyParams != nil { | ||
urlEncoded = utils.GetURLFormedMap(request.BodyParams) | ||
} | ||
httpRequest, err := http.NewRequest(request.Method, request.URL, strings.NewReader(urlEncoded)) | ||
if err != nil { | ||
return | ||
} | ||
httpRequest.Proto = "HTTP/1.1" | ||
httpRequest.Host = request.Domain | ||
debuglog("> %s %s %s", httpRequest.Method, httpRequest.URL.RequestURI(), httpRequest.Proto) | ||
debuglog("> Host: %s", httpRequest.Host) | ||
for key, value := range request.Headers { | ||
if value != "" { | ||
debuglog("> %s: %s", key, value) | ||
httpRequest.Header[key] = []string{value} | ||
} | ||
} | ||
debuglog(">") | ||
httpClient := &http.Client{} | ||
httpClient.Timeout = time.Duration(runtime.ReadTimeout) * time.Second | ||
proxy := &url.URL{} | ||
if runtime.Proxy != "" { | ||
proxy, err = url.Parse(runtime.Proxy) | ||
if err != nil { | ||
return | ||
} | ||
} | ||
trans := &http.Transport{} | ||
if proxy != nil && runtime.Proxy != "" { | ||
trans.Proxy = http.ProxyURL(proxy) | ||
} | ||
trans.DialContext = utils.Timeout(time.Duration(runtime.ConnectTimeout) * time.Second) | ||
httpClient.Transport = trans | ||
httpResponse, err := hookDo(httpClient.Do)(httpRequest) | ||
if err != nil { | ||
return | ||
} | ||
debuglog("< %s %s", httpResponse.Proto, httpResponse.Status) | ||
for key, value := range httpResponse.Header { | ||
debuglog("< %s: %v", key, strings.Join(value, "")) | ||
} | ||
debuglog("<") | ||
|
||
resp := &response.CommonResponse{} | ||
err = hookParse(resp.ParseFromHTTPResponse(httpResponse)) | ||
if err != nil { | ||
return | ||
} | ||
debuglog("%s", resp.GetHTTPContentString()) | ||
if resp.GetHTTPStatus() != http.StatusOK { | ||
err = fmt.Errorf("httpStatus: %d, message = %s", resp.GetHTTPStatus(), resp.GetHTTPContentString()) | ||
return | ||
} | ||
return resp.GetHTTPContentBytes(), nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
package providers | ||
|
||
import ( | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"io/ioutil" | ||
"net/http" | ||
"net/url" | ||
"strconv" | ||
"strings" | ||
"time" | ||
|
||
"github.com/aliyun/credentials-go/credentials/internal/utils" | ||
) | ||
|
||
type assumedRoleUser struct { | ||
} | ||
|
||
type credentials struct { | ||
SecurityToken *string `json:"SecurityToken"` | ||
Expiration *string `json:"Expiration"` | ||
AccessKeySecret *string `json:"AccessKeySecret"` | ||
AccessKeyId *string `json:"AccessKeyId"` | ||
} | ||
|
||
type assumeRoleResponse struct { | ||
RequestID *string `json:"RequestId"` | ||
AssumedRoleUser *assumedRoleUser `json:"AssumedRoleUser"` | ||
Credentials *credentials `json:"Credentials"` | ||
} | ||
|
||
type sessionCredentials struct { | ||
AccessKeyId string | ||
AccessKeySecret string | ||
SecurityToken string | ||
Expiration string | ||
} | ||
|
||
type RAMRoleARNCredentialsProvider struct { | ||
credentialsProvider CredentialsProvider | ||
roleArn string | ||
roleSessionName string | ||
durationSeconds int | ||
policy string | ||
stsRegion string | ||
externalId string | ||
// inner | ||
expirationTimestamp int64 | ||
lastUpdateTimestamp int64 | ||
sessionCredentials *sessionCredentials | ||
} | ||
|
||
type RAMRoleARNCredentialsProviderBuilder struct { | ||
provider *RAMRoleARNCredentialsProvider | ||
} | ||
|
||
func NewRAMRoleARNCredentialsProviderBuilder() *RAMRoleARNCredentialsProviderBuilder { | ||
return &RAMRoleARNCredentialsProviderBuilder{ | ||
provider: &RAMRoleARNCredentialsProvider{}, | ||
} | ||
} | ||
|
||
func (builder *RAMRoleARNCredentialsProviderBuilder) WithCredentialsProvider(credentialsProvider CredentialsProvider) *RAMRoleARNCredentialsProviderBuilder { | ||
builder.provider.credentialsProvider = credentialsProvider | ||
return builder | ||
} | ||
|
||
func (builder *RAMRoleARNCredentialsProviderBuilder) WithRoleArn(roleArn string) *RAMRoleARNCredentialsProviderBuilder { | ||
builder.provider.roleArn = roleArn | ||
return builder | ||
} | ||
|
||
func (builder *RAMRoleARNCredentialsProviderBuilder) WithStsRegion(regionId string) *RAMRoleARNCredentialsProviderBuilder { | ||
builder.provider.stsRegion = regionId | ||
return builder | ||
} | ||
|
||
func (builder *RAMRoleARNCredentialsProviderBuilder) WithRoleSessionName(roleSessionName string) *RAMRoleARNCredentialsProviderBuilder { | ||
builder.provider.roleSessionName = roleSessionName | ||
return builder | ||
} | ||
|
||
func (builder *RAMRoleARNCredentialsProviderBuilder) WithPolicy(policy string) *RAMRoleARNCredentialsProviderBuilder { | ||
builder.provider.policy = policy | ||
return builder | ||
} | ||
|
||
func (builder *RAMRoleARNCredentialsProviderBuilder) WithExternalId(externalId string) *RAMRoleARNCredentialsProviderBuilder { | ||
builder.provider.externalId = externalId | ||
return builder | ||
} | ||
|
||
func (builder *RAMRoleARNCredentialsProviderBuilder) WithDurationSeconds(durationSeconds int) *RAMRoleARNCredentialsProviderBuilder { | ||
builder.provider.durationSeconds = durationSeconds | ||
return builder | ||
} | ||
|
||
func (builder *RAMRoleARNCredentialsProviderBuilder) Build() (provider *RAMRoleARNCredentialsProvider, err error) { | ||
if builder.provider.credentialsProvider == nil { | ||
err = errors.New("must specify a previous credentials provider to asssume role") | ||
return | ||
} | ||
|
||
if builder.provider.roleArn == "" { | ||
err = errors.New("the RoleArn is empty") | ||
return | ||
} | ||
|
||
if builder.provider.roleSessionName == "" { | ||
builder.provider.roleSessionName = "credentials-go-" + strconv.FormatInt(time.Now().UnixNano()/1000, 10) | ||
} | ||
|
||
// duration seconds | ||
if builder.provider.durationSeconds == 0 { | ||
// default to 3600 | ||
builder.provider.durationSeconds = 3600 | ||
} | ||
|
||
if builder.provider.durationSeconds < 900 { | ||
err = errors.New("session duration should be in the range of 900s - max session duration") | ||
return | ||
} | ||
|
||
provider = builder.provider | ||
return | ||
} | ||
|
||
func (provider *RAMRoleARNCredentialsProvider) getCredentials(cc *Credentials) (session *sessionCredentials, err error) { | ||
method := "POST" | ||
var host string | ||
if provider.stsRegion != "" { | ||
host = fmt.Sprintf("sts.%s.aliyuncs.com", provider.stsRegion) | ||
} else { | ||
host = "sts.aliyuncs.com" | ||
} | ||
|
||
queries := make(map[string]string) | ||
queries["Version"] = "2015-04-01" | ||
queries["Action"] = "AssumeRole" | ||
queries["Format"] = "JSON" | ||
queries["Timestamp"] = utils.GetTimeInFormatISO8601() | ||
queries["SignatureMethod"] = "HMAC-SHA1" | ||
queries["SignatureVersion"] = "1.0" | ||
queries["SignatureNonce"] = utils.GetNonce() | ||
queries["AccessKeyId"] = cc.AccessKeyId | ||
if cc.SecurityToken != "" { | ||
queries["SecurityToken"] = cc.SecurityToken | ||
} | ||
|
||
bodyForm := make(map[string]string) | ||
bodyForm["RoleArn"] = provider.roleArn | ||
if provider.policy != "" { | ||
bodyForm["Policy"] = provider.policy | ||
} | ||
if provider.externalId != "" { | ||
bodyForm["ExternalId"] = provider.externalId | ||
} | ||
bodyForm["RoleSessionName"] = provider.roleSessionName | ||
bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds) | ||
|
||
// caculate signature | ||
signParams := make(map[string]string) | ||
for key, value := range queries { | ||
signParams[key] = value | ||
} | ||
for key, value := range bodyForm { | ||
signParams[key] = value | ||
} | ||
|
||
stringToSign := utils.GetURLFormedMap(signParams) | ||
stringToSign = strings.Replace(stringToSign, "+", "%20", -1) | ||
stringToSign = strings.Replace(stringToSign, "*", "%2A", -1) | ||
stringToSign = strings.Replace(stringToSign, "%7E", "~", -1) | ||
stringToSign = url.QueryEscape(stringToSign) | ||
stringToSign = method + "&%2F&" + stringToSign | ||
secret := cc.AccessKeySecret + "&" | ||
queries["Signature"] = utils.ShaHmac1(stringToSign, secret) | ||
|
||
querystring := utils.GetURLFormedMap(queries) | ||
// do request | ||
httpUrl := fmt.Sprintf("https://%s/?%s", host, querystring) | ||
|
||
body := utils.GetURLFormedMap(bodyForm) | ||
|
||
httpRequest, err := hookNewRequest(http.NewRequest)(method, httpUrl, strings.NewReader(body)) | ||
if err != nil { | ||
return | ||
} | ||
|
||
// set headers | ||
httpRequest.Header["Accept-Encoding"] = []string{"identity"} | ||
httpRequest.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"} | ||
httpRequest.Header["x-credentials-provider"] = []string{cc.ProviderName} | ||
httpClient := &http.Client{} | ||
|
||
httpResponse, err := hookDo(httpClient.Do)(httpRequest) | ||
if err != nil { | ||
return | ||
} | ||
|
||
defer httpResponse.Body.Close() | ||
|
||
responseBody, err := ioutil.ReadAll(httpResponse.Body) | ||
if err != nil { | ||
return | ||
} | ||
|
||
if httpResponse.StatusCode != http.StatusOK { | ||
err = errors.New("refresh session token failed: " + string(responseBody)) | ||
return | ||
} | ||
var data assumeRoleResponse | ||
err = json.Unmarshal(responseBody, &data) | ||
if err != nil { | ||
err = fmt.Errorf("refresh RoleArn sts token err, json.Unmarshal fail: %s", err.Error()) | ||
return | ||
} | ||
if data.Credentials == nil { | ||
err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials") | ||
return | ||
} | ||
|
||
if data.Credentials.AccessKeyId == nil || data.Credentials.AccessKeySecret == nil || data.Credentials.SecurityToken == nil { | ||
err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials") | ||
return | ||
} | ||
|
||
session = &sessionCredentials{ | ||
AccessKeyId: *data.Credentials.AccessKeyId, | ||
AccessKeySecret: *data.Credentials.AccessKeySecret, | ||
SecurityToken: *data.Credentials.SecurityToken, | ||
Expiration: *data.Credentials.Expiration, | ||
} | ||
return | ||
} | ||
|
||
func (provider *RAMRoleARNCredentialsProvider) needUpdateCredential() (result bool) { | ||
if provider.expirationTimestamp == 0 { | ||
return true | ||
} | ||
|
||
return provider.expirationTimestamp-time.Now().Unix() <= 180 | ||
} | ||
|
||
func (provider *RAMRoleARNCredentialsProvider) GetCredentials() (cc *Credentials, err error) { | ||
if provider.sessionCredentials == nil || provider.needUpdateCredential() { | ||
// 获取前置凭证 | ||
previousCredentials, err1 := provider.credentialsProvider.GetCredentials() | ||
if err1 != nil { | ||
return nil, err1 | ||
} | ||
sessionCredentials, err2 := provider.getCredentials(previousCredentials) | ||
if err2 != nil { | ||
return nil, err2 | ||
} | ||
|
||
expirationTime, err := time.Parse("2006-01-02T15:04:05Z", sessionCredentials.Expiration) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
provider.expirationTimestamp = expirationTime.Unix() | ||
provider.lastUpdateTimestamp = time.Now().Unix() | ||
provider.sessionCredentials = sessionCredentials | ||
} | ||
|
||
cc = &Credentials{ | ||
AccessKeyId: provider.sessionCredentials.AccessKeyId, | ||
AccessKeySecret: provider.sessionCredentials.AccessKeySecret, | ||
SecurityToken: provider.sessionCredentials.SecurityToken, | ||
} | ||
return | ||
} |
Oops, something went wrong.