diff --git a/credentials/credential.go b/credentials/credential.go index 70e6c60..d555e91 100644 --- a/credentials/credential.go +++ b/credentials/credential.go @@ -207,30 +207,28 @@ func NewCredential(config *Config) (credential Credential, err error) { case "credentials_uri": credential = newURLCredential(tea.StringValue(config.Url)) case "oidc_role_arn": - err = checkoutAssumeRamoidc(config) - if err != nil { - return - } runtime := &utils.Runtime{ Host: tea.StringValue(config.Host), Proxy: tea.StringValue(config.Proxy), ReadTimeout: tea.IntValue(config.Timeout), ConnectTimeout: tea.IntValue(config.ConnectTimeout), - STSEndpoint: tea.StringValue(config.STSEndpoint), } - credential, err = newOIDCRoleArnCredential( - tea.StringValue(config.AccessKeyId), - tea.StringValue(config.AccessKeySecret), - tea.StringValue(config.RoleArn), - tea.StringValue(config.OIDCProviderArn), - tea.StringValue(config.OIDCTokenFilePath), - tea.StringValue(config.RoleSessionName), - tea.StringValue(config.Policy), - tea.IntValue(config.RoleSessionExpiration), - runtime) + + provider, err := providers.NewOIDCCredentialsProviderBuilder(). + WithRoleArn(tea.StringValue(config.RoleArn)). + WithOIDCTokenFilePath(tea.StringValue(config.OIDCTokenFilePath)). + WithOIDCProviderARN(tea.StringValue(config.OIDCProviderArn)). + WithDurationSeconds(tea.IntValue(config.RoleSessionExpiration)). + WithPolicy(tea.StringValue(config.Policy)). + WithRoleSessionName(tea.StringValue(config.RoleSessionName)). + WithSTSEndpoint(tea.StringValue(config.STSEndpoint)). + WithRuntime(runtime). + Build() + if err != nil { - return + return nil, err } + credential = fromCredentialsProvider("oidc_role_arn", provider) case "access_key": provider, err := providers.NewStaticAKCredentialsProviderBuilder(). WithAccessKeyId(tea.StringValue(config.AccessKeyId)). @@ -358,18 +356,6 @@ func checkRSAKeyPair(config *Config) (err error) { return } -func checkoutAssumeRamoidc(config *Config) (err error) { - if tea.StringValue(config.RoleArn) == "" { - err = errors.New("RoleArn cannot be empty") - return - } - if tea.StringValue(config.OIDCProviderArn) == "" { - err = errors.New("OIDCProviderArn cannot be empty") - return - } - return -} - func doAction(request *request.CommonRequest, runtime *utils.Runtime) (content []byte, err error) { var urlEncoded string if request.BodyParams != nil { diff --git a/credentials/credential_test.go b/credentials/credential_test.go index 1c0bc1d..2459954 100644 --- a/credentials/credential_test.go +++ b/credentials/credential_test.go @@ -213,25 +213,22 @@ func TestNewCredentialWithOIDC(t *testing.T) { config.SetType("oidc_role_arn") cred, err := NewCredential(config) assert.NotNil(t, err) - assert.Equal(t, "RoleArn cannot be empty", err.Error()) + assert.Equal(t, "the OIDCTokenFilePath is empty", err.Error()) assert.Nil(t, cred) - config.SetRoleArn("role_arn") + config.SetOIDCTokenFilePath("oidc_token_file_path_test") cred, err = NewCredential(config) assert.NotNil(t, err) - assert.Equal(t, "OIDCProviderArn cannot be empty", err.Error()) + assert.Equal(t, "the OIDCProviderARN is empty", err.Error()) assert.Nil(t, cred) - config.SetOIDCProviderArn("oidc_provider_arn_test"). - SetRoleArn("role_arn_test") + config.SetOIDCProviderArn("oidc_provider_arn_test") cred, err = NewCredential(config) assert.NotNil(t, err) - assert.Equal(t, "the OIDC token file path is empty", err.Error()) + assert.Equal(t, "the RoleArn is empty", err.Error()) assert.Nil(t, cred) - config.SetOIDCProviderArn("oidc_provider_arn_test"). - SetOIDCTokenFilePath("oidc_token_file_path_test"). - SetRoleArn("role_arn_test") + config.SetRoleArn("role_arn_test") cred, err = NewCredential(config) assert.Nil(t, err) assert.NotNil(t, cred) diff --git a/credentials/internal/providers/fixtures/mock_oidctoken b/credentials/internal/providers/fixtures/mock_oidctoken new file mode 100644 index 0000000..6e5fed8 --- /dev/null +++ b/credentials/internal/providers/fixtures/mock_oidctoken @@ -0,0 +1 @@ +mock oidc token \ No newline at end of file diff --git a/credentials/internal/providers/oidc.go b/credentials/internal/providers/oidc.go new file mode 100644 index 0000000..3b02ee1 --- /dev/null +++ b/credentials/internal/providers/oidc.go @@ -0,0 +1,267 @@ +package providers + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "os" + "strconv" + "strings" + "time" + + "github.com/aliyun/credentials-go/credentials/internal/utils" +) + +type OIDCCredentialsProvider struct { + oidcProviderARN string + oidcTokenFilePath string + roleArn string + roleSessionName string + durationSeconds int + policy string + stsRegionId string + stsEndpoint string + lastUpdateTimestamp int64 + expirationTimestamp int64 + sessionCredentials *sessionCredentials + runtime *utils.Runtime +} + +type OIDCCredentialsProviderBuilder struct { + provider *OIDCCredentialsProvider +} + +func NewOIDCCredentialsProviderBuilder() *OIDCCredentialsProviderBuilder { + return &OIDCCredentialsProviderBuilder{ + provider: &OIDCCredentialsProvider{}, + } +} + +func (b *OIDCCredentialsProviderBuilder) WithOIDCProviderARN(oidcProviderArn string) *OIDCCredentialsProviderBuilder { + b.provider.oidcProviderARN = oidcProviderArn + return b +} + +func (b *OIDCCredentialsProviderBuilder) WithOIDCTokenFilePath(oidcTokenFilePath string) *OIDCCredentialsProviderBuilder { + b.provider.oidcTokenFilePath = oidcTokenFilePath + return b +} + +func (b *OIDCCredentialsProviderBuilder) WithRoleArn(roleArn string) *OIDCCredentialsProviderBuilder { + b.provider.roleArn = roleArn + return b +} + +func (b *OIDCCredentialsProviderBuilder) WithRoleSessionName(roleSessionName string) *OIDCCredentialsProviderBuilder { + b.provider.roleSessionName = roleSessionName + return b +} + +func (b *OIDCCredentialsProviderBuilder) WithDurationSeconds(durationSeconds int) *OIDCCredentialsProviderBuilder { + b.provider.durationSeconds = durationSeconds + return b +} + +func (b *OIDCCredentialsProviderBuilder) WithStsRegionId(regionId string) *OIDCCredentialsProviderBuilder { + b.provider.stsRegionId = regionId + return b +} + +func (b *OIDCCredentialsProviderBuilder) WithPolicy(policy string) *OIDCCredentialsProviderBuilder { + b.provider.policy = policy + return b +} + +func (b *OIDCCredentialsProviderBuilder) WithSTSEndpoint(stsEndpoint string) *OIDCCredentialsProviderBuilder { + b.provider.stsEndpoint = stsEndpoint + return b +} + +func (b *OIDCCredentialsProviderBuilder) WithRuntime(runtime *utils.Runtime) *OIDCCredentialsProviderBuilder { + b.provider.runtime = runtime + return b +} + +func (b *OIDCCredentialsProviderBuilder) Build() (provider *OIDCCredentialsProvider, err error) { + if b.provider.roleSessionName == "" { + b.provider.roleSessionName = "credentials-go-" + strconv.FormatInt(time.Now().UnixNano()/1000, 10) + } + + if b.provider.oidcTokenFilePath == "" { + b.provider.oidcTokenFilePath = os.Getenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE") + } + + if b.provider.oidcTokenFilePath == "" { + err = errors.New("the OIDCTokenFilePath is empty") + return + } + + if b.provider.oidcProviderARN == "" { + b.provider.oidcProviderARN = os.Getenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN") + } + + if b.provider.oidcProviderARN == "" { + err = errors.New("the OIDCProviderARN is empty") + return + } + + if b.provider.roleArn == "" { + b.provider.roleArn = os.Getenv("ALIBABA_CLOUD_ROLE_ARN") + } + + if b.provider.roleArn == "" { + err = errors.New("the RoleArn is empty") + return + } + + if b.provider.durationSeconds == 0 { + b.provider.durationSeconds = 3600 + } + + if b.provider.durationSeconds < 900 { + err = errors.New("the Assume Role session duration should be in the range of 15min - max duration seconds") + } + + if b.provider.stsEndpoint == "" { + if b.provider.stsRegionId != "" { + b.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", b.provider.stsRegionId) + } else { + b.provider.stsEndpoint = "sts.aliyuncs.com" + } + } + + provider = b.provider + return +} + +func (provider *OIDCCredentialsProvider) getCredentials() (session *sessionCredentials, err error) { + method := "POST" + host := provider.stsEndpoint + queries := make(map[string]string) + queries["Version"] = "2015-04-01" + queries["Action"] = "AssumeRoleWithOIDC" + queries["Format"] = "JSON" + queries["Timestamp"] = utils.GetTimeInFormatISO8601() + + bodyForm := make(map[string]string) + bodyForm["RoleArn"] = provider.roleArn + bodyForm["OIDCProviderArn"] = provider.oidcProviderARN + token, err := ioutil.ReadFile(provider.oidcTokenFilePath) + if err != nil { + return + } + + bodyForm["OIDCToken"] = string(token) + if provider.policy != "" { + bodyForm["Policy"] = provider.policy + } + + bodyForm["RoleSessionName"] = provider.roleSessionName + bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds) + + // caculate signature + signParams := make(map[string]string) + for key, value := range queries { + signParams[key] = value + } + for key, value := range bodyForm { + signParams[key] = value + } + + querystring := utils.GetURLFormedMap(queries) + // do request + httpUrl := fmt.Sprintf("https://%s/?%s", host, querystring) + + body := utils.GetURLFormedMap(bodyForm) + + httpRequest, err := hookNewRequest(http.NewRequest)(method, httpUrl, strings.NewReader(body)) + if err != nil { + return + } + + // set headers + httpRequest.Header["Accept-Encoding"] = []string{"identity"} + httpRequest.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"} + httpClient := &http.Client{} + + httpResponse, err := hookDo(httpClient.Do)(httpRequest) + if err != nil { + return + } + + defer httpResponse.Body.Close() + + responseBody, err := ioutil.ReadAll(httpResponse.Body) + if err != nil { + return + } + + if httpResponse.StatusCode != http.StatusOK { + message := "get session token failed: " + err = errors.New(message + string(responseBody)) + return + } + var data assumeRoleResponse + err = json.Unmarshal(responseBody, &data) + if err != nil { + err = fmt.Errorf("get oidc sts token err, json.Unmarshal fail: %s", err.Error()) + return + } + if data.Credentials == nil { + err = fmt.Errorf("get oidc sts token err, fail to get credentials") + return + } + + if data.Credentials.AccessKeyId == nil || data.Credentials.AccessKeySecret == nil || data.Credentials.SecurityToken == nil { + err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials") + return + } + + session = &sessionCredentials{ + AccessKeyId: *data.Credentials.AccessKeyId, + AccessKeySecret: *data.Credentials.AccessKeySecret, + SecurityToken: *data.Credentials.SecurityToken, + Expiration: *data.Credentials.Expiration, + } + return +} + +func (provider *OIDCCredentialsProvider) needUpdateCredential() (result bool) { + if provider.expirationTimestamp == 0 { + return true + } + + return provider.expirationTimestamp-time.Now().Unix() <= 180 +} + +func (provider *OIDCCredentialsProvider) GetCredentials() (cc *Credentials, err error) { + if provider.sessionCredentials == nil || provider.needUpdateCredential() { + sessionCredentials, err1 := provider.getCredentials() + if err1 != nil { + return nil, err1 + } + + provider.sessionCredentials = sessionCredentials + expirationTime, err2 := time.Parse("2006-01-02T15:04:05Z", sessionCredentials.Expiration) + if err2 != nil { + return nil, err2 + } + + provider.lastUpdateTimestamp = time.Now().Unix() + provider.expirationTimestamp = expirationTime.Unix() + } + + cc = &Credentials{ + AccessKeyId: provider.sessionCredentials.AccessKeyId, + AccessKeySecret: provider.sessionCredentials.AccessKeySecret, + SecurityToken: provider.sessionCredentials.SecurityToken, + ProviderName: provider.GetProviderName(), + } + return +} + +func (provider *OIDCCredentialsProvider) GetProviderName() string { + return "oidc" +} diff --git a/credentials/internal/providers/oidc_test.go b/credentials/internal/providers/oidc_test.go new file mode 100644 index 0000000..96960e4 --- /dev/null +++ b/credentials/internal/providers/oidc_test.go @@ -0,0 +1,364 @@ +package providers + +import ( + "errors" + "io" + "io/ioutil" + "net/http" + "os" + "path" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestOIDCCredentialsProviderGetCredentialsWithError(t *testing.T) { + wd, _ := os.Getwd() + p, err := NewOIDCCredentialsProviderBuilder(). + // read a normal token + WithOIDCTokenFilePath(path.Join(wd, "fixtures/mock_oidctoken")). + WithOIDCProviderARN("provider-arn"). + WithRoleArn("roleArn"). + WithRoleSessionName("rsn"). + WithPolicy("policy"). + WithDurationSeconds(1000). + Build() + + assert.Nil(t, err) + _, err = p.GetCredentials() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "AuthenticationFail.NoPermission") +} + +func TestNewOIDCCredentialsProvider(t *testing.T) { + _, err := NewOIDCCredentialsProviderBuilder().Build() + assert.NotNil(t, err) + assert.Equal(t, "the OIDCTokenFilePath is empty", err.Error()) + + _, err = NewOIDCCredentialsProviderBuilder().WithOIDCTokenFilePath("/path/to/invalid/oidc.token").Build() + assert.NotNil(t, err) + assert.Equal(t, "the OIDCProviderARN is empty", err.Error()) + + _, err = NewOIDCCredentialsProviderBuilder(). + WithOIDCTokenFilePath("/path/to/invalid/oidc.token"). + WithOIDCProviderARN("provider-arn"). + Build() + assert.NotNil(t, err) + assert.Equal(t, "the RoleArn is empty", err.Error()) + + p, err := NewOIDCCredentialsProviderBuilder(). + WithOIDCTokenFilePath("/path/to/invalid/oidc.token"). + WithOIDCProviderARN("provider-arn"). + WithRoleArn("roleArn"). + Build() + assert.Nil(t, err) + + assert.Equal(t, "/path/to/invalid/oidc.token", p.oidcTokenFilePath) + assert.True(t, strings.HasPrefix(p.roleSessionName, "credentials-go-")) + assert.Equal(t, 3600, p.durationSeconds) + + _, err = NewOIDCCredentialsProviderBuilder(). + WithOIDCTokenFilePath("/path/to/invalid/oidc.token"). + WithOIDCProviderARN("provider-arn"). + WithRoleArn("roleArn"). + WithDurationSeconds(100). + Build() + assert.NotNil(t, err) + assert.Equal(t, "the Assume Role session duration should be in the range of 15min - max duration seconds", err.Error()) + + os.Setenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "/path/from/env") + os.Setenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "provider_arn_from_env") + os.Setenv("ALIBABA_CLOUD_ROLE_ARN", "role_arn_from_env") + + defer func() { + os.Unsetenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE") + os.Unsetenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN") + os.Unsetenv("ALIBABA_CLOUD_ROLE_ARN") + }() + + p, err = NewOIDCCredentialsProviderBuilder(). + Build() + assert.Nil(t, err) + + assert.Equal(t, "/path/from/env", p.oidcTokenFilePath) + assert.Equal(t, "provider_arn_from_env", p.oidcProviderARN) + assert.Equal(t, "role_arn_from_env", p.roleArn) + // sts endpoint: default + assert.Equal(t, "sts.aliyuncs.com", p.stsEndpoint) + // sts endpoint: with sts endpoint + p, err = NewOIDCCredentialsProviderBuilder(). + WithSTSEndpoint("sts.cn-shanghai.aliyuncs.com"). + Build() + assert.Nil(t, err) + assert.Equal(t, "sts.cn-shanghai.aliyuncs.com", p.stsEndpoint) + + // sts endpoint: with sts regionId + p, err = NewOIDCCredentialsProviderBuilder(). + WithStsRegionId("cn-beijing"). + Build() + assert.Nil(t, err) + assert.Equal(t, "sts.cn-beijing.aliyuncs.com", p.stsEndpoint) + + p, err = NewOIDCCredentialsProviderBuilder(). + WithOIDCTokenFilePath("/path/to/invalid/oidc.token"). + WithOIDCProviderARN("provider-arn"). + WithRoleArn("roleArn"). + WithRoleSessionName("rsn"). + WithStsRegionId("cn-hangzhou"). + WithPolicy("policy"). + Build() + assert.Nil(t, err) + + assert.Equal(t, "/path/to/invalid/oidc.token", p.oidcTokenFilePath) + assert.Equal(t, "provider-arn", p.oidcProviderARN) + assert.Equal(t, "roleArn", p.roleArn) + assert.Equal(t, "rsn", p.roleSessionName) + assert.Equal(t, "cn-hangzhou", p.stsRegionId) + assert.Equal(t, "policy", p.policy) + assert.Equal(t, 3600, p.durationSeconds) + assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", p.stsEndpoint) +} + +func TestOIDCCredentialsProvider_getCredentials(t *testing.T) { + // case 0: invalid oidc token file path + p, err := NewOIDCCredentialsProviderBuilder(). + WithOIDCTokenFilePath("/path/to/invalid/oidc.token"). + WithOIDCProviderARN("provider-arn"). + WithRoleArn("roleArn"). + WithRoleSessionName("rsn"). + WithStsRegionId("cn-hangzhou"). + WithPolicy("policy"). + Build() + assert.Nil(t, err) + + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "open /path/to/invalid/oidc.token: no such file or directory", err.Error()) + + // case 1: mock new http request failed + wd, _ := os.Getwd() + p, err = NewOIDCCredentialsProviderBuilder(). + // read a normal token + WithOIDCTokenFilePath(path.Join(wd, "fixtures/mock_oidctoken")). + WithOIDCProviderARN("provider-arn"). + WithRoleArn("roleArn"). + WithRoleSessionName("rsn"). + WithStsRegionId("cn-hangzhou"). + WithPolicy("policy"). + Build() + assert.Nil(t, err) + + originNewRequest := hookNewRequest + defer func() { hookNewRequest = originNewRequest }() + + hookNewRequest = func(fn newReuqest) newReuqest { + return func(method, url string, body io.Reader) (*http.Request, error) { + return nil, errors.New("new http request failed") + } + } + + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "new http request failed", err.Error()) + + // reset new request + hookNewRequest = originNewRequest + + originDo := hookDo + defer func() { hookDo = originDo }() + + // case 2: server error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + err = errors.New("mock server error") + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "mock server error", err.Error()) + + // case 3: mock read response error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + status := strconv.Itoa(200) + res = &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + Header: map[string][]string{}, + StatusCode: 200, + Status: status + " " + http.StatusText(200), + } + res.Body = ioutil.NopCloser(&errorReader{}) + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "read failed", err.Error()) + + // case 4: 4xx error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(400, "4xx error") + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "get session token failed: 4xx error", err.Error()) + + // case 5: invalid json + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(200, "invalid json") + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "get oidc sts token err, json.Unmarshal fail: invalid character 'i' looking for beginning of value", err.Error()) + + // case 6: empty response json + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(200, "null") + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "get oidc sts token err, fail to get credentials", err.Error()) + + // case 7: empty session ak response json + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(200, `{"Credentials": {}}`) + return + } + } + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "refresh RoleArn sts token err, fail to get credentials", err.Error()) + + // case 8: mock ok value + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(200, `{"Credentials": {"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token"}}`) + return + } + } + creds, err := p.getCredentials() + assert.Nil(t, err) + assert.Equal(t, "saki", creds.AccessKeyId) + assert.Equal(t, "saks", creds.AccessKeySecret) + assert.Equal(t, "token", creds.SecurityToken) + assert.Equal(t, "2021-10-20T04:27:09Z", creds.Expiration) + + // needUpdateCredential + assert.True(t, p.needUpdateCredential()) + p.expirationTimestamp = time.Now().Unix() + assert.True(t, p.needUpdateCredential()) + + p.expirationTimestamp = time.Now().Unix() + 300 + assert.False(t, p.needUpdateCredential()) +} + +func TestOIDCCredentialsProvider_getCredentialsWithRequestCheck(t *testing.T) { + originDo := hookDo + defer func() { hookDo = originDo }() + + // case 1: mock new http request failed + wd, _ := os.Getwd() + p, err := NewOIDCCredentialsProviderBuilder(). + // read a normal token + WithOIDCTokenFilePath(path.Join(wd, "fixtures/mock_oidctoken")). + WithOIDCProviderARN("provider-arn"). + WithRoleArn("roleArn"). + WithRoleSessionName("rsn"). + WithPolicy("policy"). + WithDurationSeconds(1000). + Build() + + assert.Nil(t, err) + + // case 1: server error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + assert.Equal(t, "sts.aliyuncs.com", req.Host) + assert.Contains(t, req.URL.String(), "Action=AssumeRoleWithOIDC") + body, err := ioutil.ReadAll(req.Body) + assert.Nil(t, err) + bodyString := string(body) + assert.Contains(t, bodyString, "Policy=policy") + assert.Contains(t, bodyString, "RoleArn=roleArn") + assert.Contains(t, bodyString, "RoleSessionName=rsn") + assert.Contains(t, bodyString, "DurationSeconds=1000") + + err = errors.New("mock server error") + return + } + } + + _, err = p.getCredentials() + assert.NotNil(t, err) + assert.Equal(t, "mock server error", err.Error()) +} + +func TestOIDCCredentialsProviderGetCredentials(t *testing.T) { + originDo := hookDo + defer func() { hookDo = originDo }() + + // case 1: mock new http request failed + wd, _ := os.Getwd() + p, err := NewOIDCCredentialsProviderBuilder(). + // read a normal token + WithOIDCTokenFilePath(path.Join(wd, "fixtures/mock_oidctoken")). + WithOIDCProviderARN("provider-arn"). + WithRoleArn("roleArn"). + WithRoleSessionName("rsn"). + WithPolicy("policy"). + WithDurationSeconds(1000). + Build() + + assert.Nil(t, err) + + // case 1: get credentials failed + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + err = errors.New("mock server error") + return + } + } + _, err = p.GetCredentials() + assert.NotNil(t, err) + assert.Equal(t, "mock server error", err.Error()) + + // case 2: get invalid expiration + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(200, `{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"invalidexpiration","SecurityToken":"ststoken"}}`) + return + } + } + _, err = p.GetCredentials() + assert.NotNil(t, err) + assert.Equal(t, "parsing time \"invalidexpiration\" as \"2006-01-02T15:04:05Z\": cannot parse \"invalidexpiration\" as \"2006\"", err.Error()) + + // case 3: happy result + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = mockResponse(200, `{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"ststoken"}}`) + return + } + } + cc, err := p.GetCredentials() + assert.Nil(t, err) + assert.Equal(t, "akid", cc.AccessKeyId) + assert.Equal(t, "aksecret", cc.AccessKeySecret) + assert.Equal(t, "ststoken", cc.SecurityToken) + assert.True(t, p.needUpdateCredential()) +} diff --git a/credentials/oidc_credentials_provider.go b/credentials/oidc_credentials_provider.go deleted file mode 100644 index 272a6da..0000000 --- a/credentials/oidc_credentials_provider.go +++ /dev/null @@ -1,206 +0,0 @@ -package credentials - -import ( - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "os" - "strconv" - "time" - - "github.com/alibabacloud-go/tea/tea" - "github.com/aliyun/credentials-go/credentials/internal/utils" - "github.com/aliyun/credentials-go/credentials/request" -) - -// OIDCCredential is a kind of credentials -type OIDCCredentialsProvider struct { - *credentialUpdater - AccessKeyId string - AccessKeySecret string - RoleArn string - OIDCProviderArn string - OIDCTokenFilePath string - Policy string - RoleSessionName string - RoleSessionExpiration int - sessionCredential *sessionCredential - runtime *utils.Runtime -} - -type OIDCResponse struct { - Credentials *credentialsInResponse `json:"Credentials" xml:"Credentials"` -} - -type OIDCcredentialsInResponse struct { - AccessKeyId string `json:"AccessKeyId" xml:"AccessKeyId"` - AccessKeySecret string `json:"AccessKeySecret" xml:"AccessKeySecret"` - SecurityToken string `json:"SecurityToken" xml:"SecurityToken"` - Expiration string `json:"Expiration" xml:"Expiration"` -} - -func newOIDCRoleArnCredential(accessKeyId, accessKeySecret, roleArn, OIDCProviderArn, OIDCTokenFilePath, RoleSessionName, policy string, RoleSessionExpiration int, runtime *utils.Runtime) (provider *OIDCCredentialsProvider, err error) { - if OIDCTokenFilePath == "" { - OIDCTokenFilePath = os.Getenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE") - } - - if OIDCTokenFilePath == "" { - err = errors.New("the OIDC token file path is empty") - return - } - - provider = &OIDCCredentialsProvider{ - AccessKeyId: accessKeyId, - AccessKeySecret: accessKeySecret, - RoleArn: roleArn, - OIDCProviderArn: OIDCProviderArn, - OIDCTokenFilePath: OIDCTokenFilePath, - RoleSessionName: RoleSessionName, - Policy: policy, - RoleSessionExpiration: RoleSessionExpiration, - credentialUpdater: new(credentialUpdater), - runtime: runtime, - } - return -} - -func (e *OIDCCredentialsProvider) GetCredential() (*CredentialModel, error) { - if e.sessionCredential == nil || e.needUpdateCredential() { - err := e.updateCredential() - if err != nil { - return nil, err - } - } - credential := &CredentialModel{ - AccessKeyId: tea.String(e.sessionCredential.AccessKeyId), - AccessKeySecret: tea.String(e.sessionCredential.AccessKeySecret), - SecurityToken: tea.String(e.sessionCredential.SecurityToken), - Type: tea.String("oidc_role_arn"), - } - return credential, nil -} - -// GetAccessKeyId reutrns OIDCCredential's AccessKeyId -// if AccessKeyId is not exist or out of date, the function will update it. -func (r *OIDCCredentialsProvider) GetAccessKeyId() (accessKeyId *string, err error) { - c, err := r.GetCredential() - if err != nil { - return - } - - accessKeyId = c.AccessKeyId - return -} - -// GetAccessSecret reutrns OIDCCredential's AccessKeySecret -// if AccessKeySecret is not exist or out of date, the function will update it. -func (r *OIDCCredentialsProvider) GetAccessKeySecret() (accessKeySecret *string, err error) { - c, err := r.GetCredential() - if err != nil { - return - } - - accessKeySecret = c.AccessKeySecret - return -} - -// GetSecurityToken reutrns OIDCCredential's SecurityToken -// if SecurityToken is not exist or out of date, the function will update it. -func (r *OIDCCredentialsProvider) GetSecurityToken() (securityToken *string, err error) { - c, err := r.GetCredential() - if err != nil { - return - } - - securityToken = c.SecurityToken - return -} - -// GetBearerToken is useless OIDCCredential -func (r *OIDCCredentialsProvider) GetBearerToken() *string { - return tea.String("") -} - -// GetType reutrns OIDCCredential's type -func (r *OIDCCredentialsProvider) GetType() *string { - return tea.String("oidc_role_arn") -} - -var getFileContent = func(filePath string) (content string, err error) { - bytes, err := ioutil.ReadFile(filePath) - if err != nil { - return - } - - if len(bytes) == 0 { - err = fmt.Errorf("the content of %s is empty", filePath) - } - - content = string(bytes) - return -} - -func (r *OIDCCredentialsProvider) updateCredential() (err error) { - if r.runtime == nil { - r.runtime = new(utils.Runtime) - } - request := request.NewCommonRequest() - request.Domain = "sts.aliyuncs.com" - if r.runtime.STSEndpoint != "" { - request.Domain = r.runtime.STSEndpoint - } - request.Scheme = "HTTPS" - request.Method = "POST" - request.QueryParams["Timestamp"] = utils.GetTimeInFormatISO8601() - request.QueryParams["Action"] = "AssumeRoleWithOIDC" - request.QueryParams["Format"] = "JSON" - request.BodyParams["RoleArn"] = r.RoleArn - request.BodyParams["OIDCProviderArn"] = r.OIDCProviderArn - token, err := getFileContent(r.OIDCTokenFilePath) - if err != nil { - return fmt.Errorf("read oidc token file failed: %s", err.Error()) - } - - request.BodyParams["OIDCToken"] = token - if r.Policy != "" { - request.QueryParams["Policy"] = r.Policy - } - if r.RoleSessionExpiration > 0 { - request.QueryParams["DurationSeconds"] = strconv.Itoa(r.RoleSessionExpiration) - } - request.QueryParams["RoleSessionName"] = r.RoleSessionName - request.QueryParams["Version"] = "2015-04-01" - request.QueryParams["SignatureNonce"] = utils.GetUUID() - request.Headers["Host"] = request.Domain - request.Headers["Accept-Encoding"] = "identity" - request.Headers["content-type"] = "application/x-www-form-urlencoded" - request.URL = request.BuildURL() - content, err := doAction(request, r.runtime) - if err != nil { - return fmt.Errorf("get sts token failed with: %s", err.Error()) - } - var resp *OIDCResponse - err = json.Unmarshal(content, &resp) - if err != nil { - return fmt.Errorf("get sts token failed with: Json.Unmarshal fail: %s", err.Error()) - } - if resp == nil || resp.Credentials == nil { - return fmt.Errorf("get sts token failed with: credentials is empty") - } - respCredentials := resp.Credentials - if respCredentials.AccessKeyId == "" || respCredentials.AccessKeySecret == "" || respCredentials.SecurityToken == "" || respCredentials.Expiration == "" { - return fmt.Errorf("get sts token failed with: AccessKeyId: %s, AccessKeySecret: %s, SecurityToken: %s, Expiration: %s", respCredentials.AccessKeyId, respCredentials.AccessKeySecret, respCredentials.SecurityToken, respCredentials.Expiration) - } - - expirationTime, err := time.Parse("2006-01-02T15:04:05Z", respCredentials.Expiration) - r.lastUpdateTimestamp = time.Now().Unix() - r.credentialExpiration = int(expirationTime.Unix() - time.Now().Unix()) - r.sessionCredential = &sessionCredential{ - AccessKeyId: respCredentials.AccessKeyId, - AccessKeySecret: respCredentials.AccessKeySecret, - SecurityToken: respCredentials.SecurityToken, - } - - return -} diff --git a/credentials/oidc_credentials_provider_test.go b/credentials/oidc_credentials_provider_test.go deleted file mode 100644 index 7f69818..0000000 --- a/credentials/oidc_credentials_provider_test.go +++ /dev/null @@ -1,198 +0,0 @@ -package credentials - -import ( - "errors" - "net/http" - "os" - "path" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/aliyun/credentials-go/credentials/internal/utils" -) - -func TestNewOidcCredentialsProvider(t *testing.T) { - _, err := newOIDCRoleArnCredential("accessKeyId", "accessKeySecret", "RoleArn", "OIDCProviderArn", "", "roleSessionName", "Policy", 3600, nil) - assert.NotNil(t, err) - assert.Equal(t, "the OIDC token file path is empty", err.Error()) - - // get oidc token path from env - os.Setenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "/path/to/oidc_token") - provider, err := newOIDCRoleArnCredential("accessKeyId", "accessKeySecret", "RoleArn", "OIDCProviderArn", "", "roleSessionName", "Policy", 3600, nil) - assert.Nil(t, err) - assert.Equal(t, "/path/to/oidc_token", provider.OIDCTokenFilePath) - - os.Unsetenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE") - provider, err = newOIDCRoleArnCredential("accessKeyId", "accessKeySecret", "RoleArn", "OIDCProviderArn", "/path/to/oidc_token_args", "roleSessionName", "Policy", 3600, nil) - assert.Nil(t, err) - assert.Equal(t, "/path/to/oidc_token_args", provider.OIDCTokenFilePath) -} - -func Test_oidcCredential_updateCredential(t *testing.T) { - oidcCredential, err := newOIDCRoleArnCredential("accessKeyId", "accessKeySecret", "RoleArn", "OIDCProviderArn", "/path/to/tokenFilePath", "roleSessionName", "Policy", 3600, nil) - assert.Nil(t, err) - - c, err := oidcCredential.GetCredential() - assert.NotNil(t, err) - assert.Equal(t, "read oidc token file failed: open /path/to/tokenFilePath: no such file or directory", err.Error()) - assert.Nil(t, c) - - accessKeyId, err := oidcCredential.GetAccessKeyId() - assert.NotNil(t, err) - assert.Equal(t, "read oidc token file failed: open /path/to/tokenFilePath: no such file or directory", err.Error()) - assert.Nil(t, accessKeyId) - - accessKeySecret, err := oidcCredential.GetAccessKeySecret() - assert.NotNil(t, err) - assert.Equal(t, "read oidc token file failed: open /path/to/tokenFilePath: no such file or directory", err.Error()) - assert.Nil(t, accessKeySecret) - - securityToken, err := oidcCredential.GetSecurityToken() - assert.NotNil(t, err) - assert.Equal(t, "read oidc token file failed: open /path/to/tokenFilePath: no such file or directory", err.Error()) - assert.Nil(t, securityToken) - - originGetFileContent := getFileContent - defer func() { - getFileContent = originGetFileContent - }() - getFileContent = func(filePath string) (content string, err error) { - return "token", nil - } - // mock server error - hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { - return func(req *http.Request) (*http.Response, error) { - return mockResponse(500, ``, errors.New("mock server error")) - } - } - c, err = oidcCredential.GetCredential() - assert.NotNil(t, err) - assert.Equal(t, "get sts token failed with: mock server error", err.Error()) - assert.Nil(t, c) - // mock unmarshal error - hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { - return func(req *http.Request) (*http.Response, error) { - return mockResponse(200, `invalid json`, nil) - } - } - c, err = oidcCredential.GetCredential() - assert.NotNil(t, err) - assert.Equal(t, "get sts token failed with: Json.Unmarshal fail: invalid character 'i' looking for beginning of value", err.Error()) - assert.Nil(t, c) - - // mock null response - hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { - return func(req *http.Request) (*http.Response, error) { - return mockResponse(200, `null`, nil) - } - } - c, err = oidcCredential.GetCredential() - assert.NotNil(t, err) - assert.Equal(t, "get sts token failed with: credentials is empty", err.Error()) - assert.Nil(t, c) - - hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { - return func(req *http.Request) (*http.Response, error) { - return mockResponse(200, `{}`, nil) - } - } - c, err = oidcCredential.GetCredential() - assert.NotNil(t, err) - assert.Equal(t, "get sts token failed with: credentials is empty", err.Error()) - assert.Nil(t, c) - - // mock empty ak - hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { - return func(req *http.Request) (*http.Response, error) { - return mockResponse(200, `{"Credentials": {}}`, nil) - } - } - c, err = oidcCredential.GetCredential() - assert.NotNil(t, err) - assert.Equal(t, "get sts token failed with: AccessKeyId: , AccessKeySecret: , SecurityToken: , Expiration: ", err.Error()) - assert.Nil(t, c) - - // mock normal credentials - hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { - return func(req *http.Request) (*http.Response, error) { - return mockResponse(200, `{"Credentials": {"AccessKeyId":"akid","AccessKeySecret":"aksecret","SecurityToken":"ststoken","Expiration":"2006-01-02T15:04:05Z"}}`, nil) - } - } - c, err = oidcCredential.GetCredential() - assert.Nil(t, err) - assert.NotNil(t, c) - assert.Equal(t, "akid", *c.AccessKeyId) - assert.Equal(t, "aksecret", *c.AccessKeySecret) - assert.Equal(t, "ststoken", *c.SecurityToken) - - akid, err := oidcCredential.GetAccessKeyId() - assert.Nil(t, err) - assert.Equal(t, "akid", *akid) - - secret, err := oidcCredential.GetAccessKeySecret() - assert.Nil(t, err) - assert.Equal(t, "aksecret", *secret) - - ststoken, err := oidcCredential.GetSecurityToken() - assert.Nil(t, err) - assert.Equal(t, "ststoken", *ststoken) -} - -func TestOIDCCredentialsProviderGetBearerToken(t *testing.T) { - provider, err := newOIDCRoleArnCredential("accessKeyId", "accessKeySecret", "RoleArn", "OIDCProviderArn", "tokenFilePath", "roleSessionName", "Policy", 3600, nil) - assert.Nil(t, err) - assert.Equal(t, "", *provider.GetBearerToken()) -} - -func TestOIDCCredentialsProviderGetType(t *testing.T) { - provider, err := newOIDCRoleArnCredential("accessKeyId", "accessKeySecret", "RoleArn", "OIDCProviderArn", "tokenFilePath", "roleSessionName", "Policy", 3600, nil) - assert.Nil(t, err) - assert.Equal(t, "oidc_role_arn", *provider.GetType()) -} - -func Test_getFileContent(t *testing.T) { - wd, _ := os.Getwd() - // read a normal token - token, err := getFileContent(path.Join(wd, "../test_fixtures/oidc_token")) - assert.Nil(t, err) - assert.Equal(t, "test_long_oidc_token_eyJhbGciOiJSUzI1NiIsImtpZCI6ImFQaXlpNEVGSU8wWnlGcFh1V0psQUNWbklZVlJsUkNmM2tlSzNMUlhWT1UifQ.eyJhdWQiOlsic3RzLmFsaXl1bmNzLmNvbSJdLCJleHAiOjE2NDUxMTk3ODAsImlhdCI6MTY0NTA4Mzc4MCwiaXNzIjoiaHR0cHM6Ly9vaWRjLWFjay1jbi1oYW5nemhvdS5vc3MtY24taGFuZ3pob3UtaW50ZXJuYWwuYWxpeXVuY3MuY29tL2NmMWQ4ZGIwMjM0ZDk0YzEyOGFiZDM3MTc4NWJjOWQxNSIsImt1YmVybmV0ZXMuaW8iOnsibmFtZXNwYWNlIjoidGVzdC1ycnNhIiwicG9kIjp7Im5hbWUiOiJydW4tYXMtcm9vdCIsInVpZCI6ImIzMGI0MGY2LWNiZTAtNGY0Yy1hZGYyLWM1OGQ4ZmExZTAxMCJ9LCJzZXJ2aWNlYWNjb3VudCI6eyJuYW1lIjoidXNlcjEiLCJ1aWQiOiJiZTEyMzdjYS01MTY4LTQyMzYtYWUyMC00NDM1YjhmMGI4YzAifX0sIm5iZiI6MTY0NTA4Mzc4MCwic3ViIjoic3lzdGVtOnNlcnZpY2VhY2NvdW50OnRlc3QtcnJzYTp1c2VyMSJ9.XGP-wgLj-iMiAHjLe0lZLh7y48Qsj9HzsEbNh706WwerBoxnssdsyGFb9lzd2FyM8CssbAOCstr7OuAMWNdJmDZgpiOGGSbQ-KXXmbfnIS4ix-V3pQF6LVBFr7xJlj20J6YY89um3rv_04t0iCGxKWs2ZMUyU1FbZpIPRep24LVKbUz1saiiVGgDBTIZdHA13Z-jUvYAnsxK_Kj5tc1K-IuQQU0IwSKJh5OShMcdPugMV5LwTL3ogCikfB7yljq5vclBhCeF2lXLIibvwF711TOhuJ5lMlh-a2KkIgwBHhANg_U9k4Mt_VadctfUGc4hxlSbBD0w9o9mDGKwgGmW5Q", token) - - // read a empty token - _, err = getFileContent(path.Join(wd, "../test_fixtures/empty_oidc_token")) - assert.NotNil(t, err) - assert.Contains(t, err.Error(), "the content of ") - assert.Contains(t, err.Error(), "/test_fixtures/empty_oidc_token is empty") - - // read a inexist token - _, err = getFileContent(path.Join(wd, "../test_fixtures/inexist_oidc_token")) - assert.NotNil(t, err) - assert.Contains(t, err.Error(), "no such file or directory") -} - -func TestSTSEndpoint(t *testing.T) { - originGetFileContent := getFileContent - defer func() { - getFileContent = originGetFileContent - }() - getFileContent = func(filePath string) (content string, err error) { - return "token", nil - } - // mock server error - hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) { - return func(req *http.Request) (*http.Response, error) { - assert.Equal(t, "sts.cn-beijing.aliyuncs.com", req.Host) - return mockResponse(500, ``, errors.New("mock server error")) - } - } - - runtime := &utils.Runtime{ - STSEndpoint: "sts.cn-beijing.aliyuncs.com", - } - provider, err := newOIDCRoleArnCredential("accessKeyId", "accessKeySecret", "RoleArn", "OIDCProviderArn", "tokenFilePath", "roleSessionName", "Policy", 3600, runtime) - assert.Nil(t, err) - c, err := provider.GetCredential() - assert.NotNil(t, err) - assert.Equal(t, "get sts token failed with: mock server error", err.Error()) - assert.Nil(t, c) -}