Skip to content

Commit

Permalink
refine test cases for OIDC
Browse files Browse the repository at this point in the history
  • Loading branch information
JacksonTian committed Aug 23, 2024
1 parent 3bd38fa commit b103b61
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 175 deletions.
7 changes: 0 additions & 7 deletions credentials/internal/providers/ecs_ram_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ import (
"time"

httputil "github.com/aliyun/credentials-go/credentials/internal/http"
"github.com/aliyun/credentials-go/credentials/internal/utils"
)

type ECSRAMRoleCredentialsProvider struct {
roleName string
metadataTokenDurationSeconds int
enableIMDSv2 bool
runtime *utils.Runtime
// for sts
session *sessionCredentials
expirationTimestamp int64
Expand Down Expand Up @@ -68,11 +66,6 @@ func (builder *ECSRAMRoleCredentialsProviderBuilder) Build() (provider *ECSRAMRo
return
}

builder.provider.runtime = &utils.Runtime{
ConnectTimeout: 5,
ReadTimeout: 5,
}

provider = builder.provider
return
}
Expand Down
77 changes: 21 additions & 56 deletions credentials/internal/providers/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"

httputil "github.com/aliyun/credentials-go/credentials/internal/http"
"github.com/aliyun/credentials-go/credentials/internal/utils"
)

Expand Down Expand Up @@ -139,13 +138,25 @@ func (b *OIDCCredentialsProviderBuilder) Build() (provider *OIDCCredentialsProvi
}

func (provider *OIDCCredentialsProvider) getCredentials() (session *sessionCredentials, err error) {
method := "POST"
host := provider.stsEndpoint
req := &httputil.Request{
Method: "POST",
Protocol: "https",
Host: provider.stsEndpoint,
Headers: map[string]string{},
}

if provider.httpOptions != nil {
req.ConnectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Second
req.ReadTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Second
req.Proxy = provider.httpOptions.Proxy
}

queries := make(map[string]string)
queries["Version"] = "2015-04-01"
queries["Action"] = "AssumeRoleWithOIDC"
queries["Format"] = "JSON"
queries["Timestamp"] = utils.GetTimeInFormatISO8601()
req.Queries = queries

bodyForm := make(map[string]string)
bodyForm["RoleArn"] = provider.roleArn
Expand All @@ -162,68 +173,22 @@ func (provider *OIDCCredentialsProvider) getCredentials() (session *sessionCrede

bodyForm["RoleSessionName"] = provider.roleSessionName
bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds)

// caculate signature
signParams := make(map[string]string)
for key, value := range queries {
signParams[key] = value
}
for key, value := range bodyForm {
signParams[key] = value
}

querystring := utils.GetURLFormedMap(queries)
// do request
httpUrl := fmt.Sprintf("https://%s/?%s", host, querystring)

body := utils.GetURLFormedMap(bodyForm)

httpRequest, err := hookNewRequest(http.NewRequest)(method, httpUrl, strings.NewReader(body))
if err != nil {
return
}
req.Form = bodyForm

// set headers
httpRequest.Header["Accept-Encoding"] = []string{"identity"}
httpRequest.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"}
httpClient := &http.Client{}

if provider.httpOptions != nil {
httpClient.Timeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Second
proxy := &url.URL{}
if provider.httpOptions.Proxy != "" {
proxy, err = url.Parse(provider.httpOptions.Proxy)
if err != nil {
return
}
}
trans := &http.Transport{}
if proxy != nil && provider.httpOptions.Proxy != "" {
trans.Proxy = http.ProxyURL(proxy)
}
trans.DialContext = utils.Timeout(time.Duration(provider.httpOptions.ConnectTimeout) * time.Second)
httpClient.Transport = trans
}

httpResponse, err := hookDo(httpClient.Do)(httpRequest)
if err != nil {
return
}

defer httpResponse.Body.Close()

responseBody, err := ioutil.ReadAll(httpResponse.Body)
req.Headers["Accept-Encoding"] = "identity"
res, err := httpDo(req)
if err != nil {
return
}

if httpResponse.StatusCode != http.StatusOK {
if res.StatusCode != http.StatusOK {
message := "get session token failed: "
err = errors.New(message + string(responseBody))
err = errors.New(message + string(res.Body))
return
}
var data assumeRoleResponse
err = json.Unmarshal(responseBody, &data)
err = json.Unmarshal(res.Body, &data)
if err != nil {
err = fmt.Errorf("get oidc sts token err, json.Unmarshal fail: %s", err.Error())
return
Expand Down
Loading

0 comments on commit b103b61

Please sign in to comment.