From b103b6112734c78965bdca0790e7d5f4571c9070 Mon Sep 17 00:00:00 2001 From: Jackson Tian Date: Fri, 23 Aug 2024 17:45:16 +0800 Subject: [PATCH] refine test cases for OIDC --- .../internal/providers/ecs_ram_role.go | 7 - credentials/internal/providers/oidc.go | 77 ++------ credentials/internal/providers/oidc_test.go | 187 +++++++----------- 3 files changed, 96 insertions(+), 175 deletions(-) diff --git a/credentials/internal/providers/ecs_ram_role.go b/credentials/internal/providers/ecs_ram_role.go index b1fde39..ac5b036 100644 --- a/credentials/internal/providers/ecs_ram_role.go +++ b/credentials/internal/providers/ecs_ram_role.go @@ -10,14 +10,12 @@ import ( "time" httputil "github.com/aliyun/credentials-go/credentials/internal/http" - "github.com/aliyun/credentials-go/credentials/internal/utils" ) type ECSRAMRoleCredentialsProvider struct { roleName string metadataTokenDurationSeconds int enableIMDSv2 bool - runtime *utils.Runtime // for sts session *sessionCredentials expirationTimestamp int64 @@ -68,11 +66,6 @@ func (builder *ECSRAMRoleCredentialsProviderBuilder) Build() (provider *ECSRAMRo return } - builder.provider.runtime = &utils.Runtime{ - ConnectTimeout: 5, - ReadTimeout: 5, - } - provider = builder.provider return } diff --git a/credentials/internal/providers/oidc.go b/credentials/internal/providers/oidc.go index 78f3654..7ff0783 100644 --- a/credentials/internal/providers/oidc.go +++ b/credentials/internal/providers/oidc.go @@ -6,12 +6,11 @@ import ( "fmt" "io/ioutil" "net/http" - "net/url" "os" "strconv" - "strings" "time" + httputil "github.com/aliyun/credentials-go/credentials/internal/http" "github.com/aliyun/credentials-go/credentials/internal/utils" ) @@ -139,13 +138,25 @@ func (b *OIDCCredentialsProviderBuilder) Build() (provider *OIDCCredentialsProvi } func (provider *OIDCCredentialsProvider) getCredentials() (session *sessionCredentials, err error) { - method := "POST" - host := provider.stsEndpoint + req := &httputil.Request{ + Method: "POST", + Protocol: "https", + Host: provider.stsEndpoint, + Headers: map[string]string{}, + } + + if provider.httpOptions != nil { + req.ConnectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Second + req.ReadTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Second + req.Proxy = provider.httpOptions.Proxy + } + queries := make(map[string]string) queries["Version"] = "2015-04-01" queries["Action"] = "AssumeRoleWithOIDC" queries["Format"] = "JSON" queries["Timestamp"] = utils.GetTimeInFormatISO8601() + req.Queries = queries bodyForm := make(map[string]string) bodyForm["RoleArn"] = provider.roleArn @@ -162,68 +173,22 @@ func (provider *OIDCCredentialsProvider) getCredentials() (session *sessionCrede bodyForm["RoleSessionName"] = provider.roleSessionName bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds) - - // caculate signature - signParams := make(map[string]string) - for key, value := range queries { - signParams[key] = value - } - for key, value := range bodyForm { - signParams[key] = value - } - - querystring := utils.GetURLFormedMap(queries) - // do request - httpUrl := fmt.Sprintf("https://%s/?%s", host, querystring) - - body := utils.GetURLFormedMap(bodyForm) - - httpRequest, err := hookNewRequest(http.NewRequest)(method, httpUrl, strings.NewReader(body)) - if err != nil { - return - } + req.Form = bodyForm // set headers - httpRequest.Header["Accept-Encoding"] = []string{"identity"} - httpRequest.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"} - httpClient := &http.Client{} - - if provider.httpOptions != nil { - httpClient.Timeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Second - proxy := &url.URL{} - if provider.httpOptions.Proxy != "" { - proxy, err = url.Parse(provider.httpOptions.Proxy) - if err != nil { - return - } - } - trans := &http.Transport{} - if proxy != nil && provider.httpOptions.Proxy != "" { - trans.Proxy = http.ProxyURL(proxy) - } - trans.DialContext = utils.Timeout(time.Duration(provider.httpOptions.ConnectTimeout) * time.Second) - httpClient.Transport = trans - } - - httpResponse, err := hookDo(httpClient.Do)(httpRequest) - if err != nil { - return - } - - defer httpResponse.Body.Close() - - responseBody, err := ioutil.ReadAll(httpResponse.Body) + req.Headers["Accept-Encoding"] = "identity" + res, err := httpDo(req) if err != nil { return } - if httpResponse.StatusCode != http.StatusOK { + if res.StatusCode != http.StatusOK { message := "get session token failed: " - err = errors.New(message + string(responseBody)) + err = errors.New(message + string(res.Body)) return } var data assumeRoleResponse - err = json.Unmarshal(responseBody, &data) + err = json.Unmarshal(res.Body, &data) if err != nil { err = fmt.Errorf("get oidc sts token err, json.Unmarshal fail: %s", err.Error()) return diff --git a/credentials/internal/providers/oidc_test.go b/credentials/internal/providers/oidc_test.go index 2553221..79c558e 100644 --- a/credentials/internal/providers/oidc_test.go +++ b/credentials/internal/providers/oidc_test.go @@ -2,16 +2,14 @@ package providers import ( "errors" - "io" - "io/ioutil" - "net/http" "os" "path" - "strconv" "strings" "testing" "time" + httputil "github.com/aliyun/credentials-go/credentials/internal/http" + "github.com/aliyun/credentials-go/credentials/internal/utils" "github.com/stretchr/testify/assert" ) @@ -25,15 +23,24 @@ func TestOIDCCredentialsProviderGetCredentialsWithError(t *testing.T) { WithRoleSessionName("rsn"). WithPolicy("policy"). WithDurationSeconds(1000). + WithHttpOptions(&HttpOptions{ + ConnectTimeout: 10, + }). Build() assert.Nil(t, err) + assert.Equal(t, 100, p.httpOptions.ConnectTimeout) _, err = p.GetCredentials() assert.NotNil(t, err) assert.Contains(t, err.Error(), "AuthenticationFail.NoPermission") } func TestNewOIDCCredentialsProvider(t *testing.T) { + rollback := utils.Memory("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "ALIBABA_CLOUD_ROLE_ARN") + defer func() { + rollback() + }() + _, err := NewOIDCCredentialsProviderBuilder().Build() assert.NotNil(t, err) assert.Equal(t, "the OIDCTokenFilePath is empty", err.Error()) @@ -73,12 +80,6 @@ func TestNewOIDCCredentialsProvider(t *testing.T) { os.Setenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "provider_arn_from_env") os.Setenv("ALIBABA_CLOUD_ROLE_ARN", "role_arn_from_env") - defer func() { - os.Unsetenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE") - os.Unsetenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN") - os.Unsetenv("ALIBABA_CLOUD_ROLE_ARN") - }() - p, err = NewOIDCCredentialsProviderBuilder(). Build() assert.Nil(t, err) @@ -123,6 +124,9 @@ func TestNewOIDCCredentialsProvider(t *testing.T) { } func TestOIDCCredentialsProvider_getCredentials(t *testing.T) { + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() + // case 0: invalid oidc token file path p, err := NewOIDCCredentialsProviderBuilder(). WithOIDCTokenFilePath("/path/to/invalid/oidc.token"). @@ -151,105 +155,70 @@ func TestOIDCCredentialsProvider_getCredentials(t *testing.T) { Build() assert.Nil(t, err) - originNewRequest := hookNewRequest - defer func() { hookNewRequest = originNewRequest }() - - hookNewRequest = func(fn newReuqest) newReuqest { - return func(method, url string, body io.Reader) (*http.Request, error) { - return nil, errors.New("new http request failed") - } - } - - _, err = p.getCredentials() - assert.NotNil(t, err) - assert.Equal(t, "new http request failed", err.Error()) - - // reset new request - hookNewRequest = originNewRequest - - originDo := hookDo - defer func() { hookDo = originDo }() - // case 2: server error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - err = errors.New("mock server error") - return - } + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + err = errors.New("mock server error") + return } _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "mock server error", err.Error()) - // case 3: mock read response error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - status := strconv.Itoa(200) - res = &http.Response{ - Proto: "HTTP/1.1", - ProtoMajor: 1, - Header: map[string][]string{}, - StatusCode: 200, - Status: status + " " + http.StatusText(200), - } - res.Body = ioutil.NopCloser(&errorReader{}) - return - } - } - _, err = p.getCredentials() - assert.NotNil(t, err) - assert.Equal(t, "read failed", err.Error()) - - // case 4: 4xx error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(400, "4xx error") - return + // case 3: 4xx error + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 400, + Body: []byte("4xx error"), } + return } _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "get session token failed: 4xx error", err.Error()) - // case 5: invalid json - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, "invalid json") - return + // case 4: invalid json + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("invalid json"), } + return } _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "get oidc sts token err, json.Unmarshal fail: invalid character 'i' looking for beginning of value", err.Error()) - // case 6: empty response json - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, "null") - return + // case 5: empty response json + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("null"), } + return } _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "get oidc sts token err, fail to get credentials", err.Error()) - // case 7: empty session ak response json - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, `{"Credentials": {}}`) - return + // case 6: empty session ak response json + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"Credentials": {}}`), } + return } _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "refresh RoleArn sts token err, fail to get credentials", err.Error()) - // case 8: mock ok value - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, `{"Credentials": {"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token"}}`) - return + // case 7: mock ok value + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"Credentials": {"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token"}}`), } + return } creds, err := p.getCredentials() assert.Nil(t, err) @@ -268,8 +237,8 @@ func TestOIDCCredentialsProvider_getCredentials(t *testing.T) { } func TestOIDCCredentialsProvider_getCredentialsWithRequestCheck(t *testing.T) { - originDo := hookDo - defer func() { hookDo = originDo }() + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() // case 1: mock new http request failed wd, _ := os.Getwd() @@ -286,31 +255,25 @@ func TestOIDCCredentialsProvider_getCredentialsWithRequestCheck(t *testing.T) { assert.Nil(t, err) // case 1: server error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - assert.Equal(t, "sts.aliyuncs.com", req.Host) - assert.Contains(t, req.URL.String(), "Action=AssumeRoleWithOIDC") - body, err := ioutil.ReadAll(req.Body) - assert.Nil(t, err) - bodyString := string(body) - assert.Contains(t, bodyString, "Policy=policy") - assert.Contains(t, bodyString, "RoleArn=roleArn") - assert.Contains(t, bodyString, "RoleSessionName=rsn") - assert.Contains(t, bodyString, "DurationSeconds=1000") - - err = errors.New("mock server error") - return - } + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + assert.Equal(t, "sts.aliyuncs.com", req.Host) + assert.Equal(t, "AssumeRoleWithOIDC", req.Queries["Action"]) + assert.Equal(t, "policy", req.Form["Policy"]) + assert.Equal(t, "roleArn", req.Form["RoleArn"]) + assert.Equal(t, "rsn", req.Form["RoleSessionName"]) + assert.Equal(t, "1000", req.Form["DurationSeconds"]) + + err = errors.New("mock server error") + return } - _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "mock server error", err.Error()) } func TestOIDCCredentialsProviderGetCredentials(t *testing.T) { - originDo := hookDo - defer func() { hookDo = originDo }() + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() // case 1: mock new http request failed wd, _ := os.Getwd() @@ -326,34 +289,34 @@ func TestOIDCCredentialsProviderGetCredentials(t *testing.T) { assert.Nil(t, err) - // case 1: get credentials failed - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - err = errors.New("mock server error") - return - } + // case 2: get credentials failed + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + err = errors.New("mock server error") + return } _, err = p.GetCredentials() assert.NotNil(t, err) assert.Equal(t, "mock server error", err.Error()) // case 2: get invalid expiration - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, `{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"invalidexpiration","SecurityToken":"ststoken"}}`) - return + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"invalidexpiration","SecurityToken":"ststoken"}}`), } + return } _, err = p.GetCredentials() assert.NotNil(t, err) assert.Equal(t, "parsing time \"invalidexpiration\" as \"2006-01-02T15:04:05Z\": cannot parse \"invalidexpiration\" as \"2006\"", err.Error()) // case 3: happy result - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, `{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"ststoken"}}`) - return + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"ststoken"}}`), } + return } cc, err := p.GetCredentials() assert.Nil(t, err)