diff --git a/credentials/internal/http/http.go b/credentials/internal/http/http.go new file mode 100644 index 0000000..4c97122 --- /dev/null +++ b/credentials/internal/http/http.go @@ -0,0 +1,135 @@ +package http + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/alibabacloud-go/debug/debug" + "github.com/aliyun/credentials-go/credentials/internal/utils" +) + +type Request struct { + Method string // http request method + Protocol string // http or https + Host string // http host + ReadTimeout time.Duration + ConnectTimeout time.Duration + Proxy string // http proxy + Form map[string]string // http form + Body []byte // request body for JSON or stream + Path string + Queries map[string]string + Headers map[string]string +} + +func (req *Request) BuildRequestURL() string { + httpUrl := fmt.Sprintf("%s://%s%s", req.Protocol, req.Host, req.Path) + + querystring := utils.GetURLFormedMap(req.Queries) + if querystring != "" { + httpUrl = httpUrl + "?" + querystring + } + + return fmt.Sprintf("%s %s", req.Method, httpUrl) +} + +type Response struct { + StatusCode int + Headers map[string]string + Body []byte +} + +var newRequest = http.NewRequest + +type do func(req *http.Request) (*http.Response, error) + +var hookDo = func(fn do) do { + return fn +} + +var debuglog = debug.Init("credential") + +func Do(req *Request) (res *Response, err error) { + querystring := utils.GetURLFormedMap(req.Queries) + // do request + httpUrl := fmt.Sprintf("%s://%s%s?%s", req.Protocol, req.Host, req.Path, querystring) + + var body io.Reader + if req.Method == "GET" { + body = strings.NewReader("") + } else { + body = strings.NewReader(utils.GetURLFormedMap(req.Form)) + } + + httpRequest, err := newRequest(req.Method, httpUrl, body) + if err != nil { + return + } + + if req.Form != nil { + httpRequest.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"} + } + + for key, value := range req.Headers { + if value != "" { + debuglog("> %s: %s", key, value) + httpRequest.Header.Set(key, value) + } + } + + httpClient := &http.Client{} + + if req.ReadTimeout != 0 { + httpClient.Timeout = req.ReadTimeout + } + + transport := &http.Transport{} + if req.Proxy != "" { + var proxy *url.URL + proxy, err = url.Parse(req.Proxy) + if err != nil { + return + } + transport.Proxy = http.ProxyURL(proxy) + } + + if req.ConnectTimeout != 0 { + transport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { + return (&net.Dialer{ + Timeout: req.ConnectTimeout, + DualStack: true, + }).DialContext(ctx, network, address) + } + } + + httpClient.Transport = transport + + httpResponse, err := hookDo(httpClient.Do)(httpRequest) + if err != nil { + return + } + + defer httpResponse.Body.Close() + + responseBody, err := ioutil.ReadAll(httpResponse.Body) + if err != nil { + return + } + res = &Response{ + StatusCode: httpResponse.StatusCode, + Headers: make(map[string]string), + Body: responseBody, + } + for key, v := range httpResponse.Header { + res.Headers[key] = v[0] + } + + return +} diff --git a/credentials/internal/http/http_test.go b/credentials/internal/http/http_test.go new file mode 100644 index 0000000..ad2d067 --- /dev/null +++ b/credentials/internal/http/http_test.go @@ -0,0 +1,189 @@ +package http + +import ( + "errors" + "io" + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRequest(t *testing.T) { + req := &Request{ + Method: "GET", + Protocol: "http", + Host: "www.aliyun.com", + Path: "/", + } + assert.Equal(t, "GET http://www.aliyun.com/", req.BuildRequestURL()) + // With query + req = &Request{ + Method: "GET", + Protocol: "http", + Host: "www.aliyun.com", + Path: "/", + Queries: map[string]string{ + "spm": "test", + }, + } + assert.Equal(t, "GET http://www.aliyun.com/?spm=test", req.BuildRequestURL()) +} + +func TestDoGet(t *testing.T) { + req := &Request{ + Method: "GET", + Protocol: "http", + Host: "www.aliyun.com", + Path: "/", + } + res, err := Do(req) + assert.Nil(t, err) + assert.NotNil(t, res) + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "text/html; charset=utf-8", res.Headers["Content-Type"]) +} + +func TestDoPost(t *testing.T) { + req := &Request{ + Method: "POST", + Protocol: "http", + Host: "www.aliyun.com", + Path: "/", + Form: map[string]string{ + "URL": "HI", + }, + Headers: map[string]string{ + "Accept-Language": "zh", + }, + } + res, err := Do(req) + assert.Nil(t, err) + assert.NotNil(t, res) + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "text/html; charset=utf-8", res.Headers["Content-Type"]) +} + +type errorReader struct { +} + +func (r *errorReader) Read(p []byte) (n int, err error) { + err = errors.New("read failed") + return +} + +func TestDoWithError(t *testing.T) { + originNewRequest := newRequest + defer func() { newRequest = originNewRequest }() + + // case 1: mock new http request failed + newRequest = func(method, url string, body io.Reader) (*http.Request, error) { + return nil, errors.New("new http request failed") + } + + req := &Request{ + Method: "POST", + Protocol: "http", + Host: "www.aliyun.com", + Path: "/", + Form: map[string]string{ + "URL": "HI", + }, + Headers: map[string]string{ + "Accept-Language": "zh", + }, + } + _, err := Do(req) + assert.EqualError(t, err, "new http request failed") + + // reset new request + newRequest = originNewRequest + + // case 2: server error + originDo := hookDo + defer func() { hookDo = originDo }() + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + err = errors.New("mock server error") + return + } + } + _, err = Do(req) + assert.EqualError(t, err, "mock server error") + + // case 4: mock read response error + hookDo = func(fn do) do { + return func(req *http.Request) (res *http.Response, err error) { + res = &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: map[string][]string{}, + StatusCode: 200, + Status: "200 " + http.StatusText(200), + } + res.Body = ioutil.NopCloser(&errorReader{}) + return + } + } + + _, err = Do(req) + assert.EqualError(t, err, "read failed") +} + +func TestDoWithProxy(t *testing.T) { + req := &Request{ + Method: "POST", + Protocol: "http", + Host: "www.aliyun.com", + Path: "/", + Form: map[string]string{ + "URL": "HI", + }, + Headers: map[string]string{ + "Accept-Language": "zh", + }, + Proxy: "http://localhost:9999/", + } + _, err := Do(req) + assert.Contains(t, err.Error(), "proxyconnect tcp: dial tcp") + assert.Contains(t, err.Error(), "connect: connection refused") + + // invalid proxy url + req.Proxy = string([]byte{0x7f}) + _, err = Do(req) + assert.Contains(t, err.Error(), "net/url: invalid control character in URL") +} + +func TestDoWithConnectTimeout(t *testing.T) { + req := &Request{ + Method: "POST", + Protocol: "http", + Host: "www.aliyun.com", + Path: "/", + Form: map[string]string{ + "URL": "HI", + }, + Headers: map[string]string{ + "Accept-Language": "zh", + }, + ConnectTimeout: 1 * time.Nanosecond, + } + _, err := Do(req) + assert.Contains(t, err.Error(), "dial tcp: ") + assert.Contains(t, err.Error(), "i/o timeout") +} + +func TestDoWithReadTimeout(t *testing.T) { + req := &Request{ + Method: "POST", + Protocol: "http", + Host: "www.aliyun.com", + Path: "/", + ReadTimeout: 1 * time.Nanosecond, + } + _, err := Do(req) + assert.Contains(t, err.Error(), "(Client.Timeout exceeded while awaiting headers)") +} diff --git a/credentials/internal/providers/ecs_ram_role.go b/credentials/internal/providers/ecs_ram_role.go index 25fd105..b1fde39 100644 --- a/credentials/internal/providers/ecs_ram_role.go +++ b/credentials/internal/providers/ecs_ram_role.go @@ -4,20 +4,15 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" - "net/http" "os" "strconv" "strings" "time" + httputil "github.com/aliyun/credentials-go/credentials/internal/http" "github.com/aliyun/credentials-go/credentials/internal/utils" - "github.com/aliyun/credentials-go/credentials/request" ) -const securityCredURL = "http://100.100.100.200/latest/meta-data/ram/security-credentials/" -const apiTokenURL = "http://100.100.100.200/latest/api/token" - type ECSRAMRoleCredentialsProvider struct { roleName string metadataTokenDurationSeconds int @@ -100,10 +95,14 @@ func (provider *ECSRAMRoleCredentialsProvider) needUpdateCredential() bool { } func (provider *ECSRAMRoleCredentialsProvider) getRoleName() (roleName string, err error) { - httpRequest, err := hookNewRequest(http.NewRequest)("GET", securityCredURL, strings.NewReader("")) - if err != nil { - err = fmt.Errorf("get role name failed: %s", err.Error()) - return + req := &httputil.Request{ + Method: "GET", + Protocol: "http", + Host: "100.100.100.200", + Path: "/latest/meta-data/ram/security-credentials/", + ConnectTimeout: 5 * time.Second, + ReadTimeout: 5 * time.Second, + Headers: map[string]string{}, } if provider.enableIMDSv2 { @@ -111,31 +110,21 @@ func (provider *ECSRAMRoleCredentialsProvider) getRoleName() (roleName string, e if err != nil { return "", err } - httpRequest.Header.Set("x-aliyun-ecs-metadata-token", metadataToken) + req.Headers["x-aliyun-ecs-metadata-token"] = metadataToken } - httpClient := &http.Client{ - Timeout: 5 * time.Second, - } - httpResponse, err := hookDo(httpClient.Do)(httpRequest) + res, err := httpDo(req) if err != nil { err = fmt.Errorf("get role name failed: %s", err.Error()) return } - if httpResponse.StatusCode != http.StatusOK { - err = fmt.Errorf("get role name failed: request %s %d", securityCredURL, httpResponse.StatusCode) + if res.StatusCode != 200 { + err = fmt.Errorf("get role name failed: %s %d", req.BuildRequestURL(), res.StatusCode) return } - defer httpResponse.Body.Close() - - responseBody, err := ioutil.ReadAll(httpResponse.Body) - if err != nil { - return - } - - roleName = strings.TrimSpace(string(responseBody)) + roleName = strings.TrimSpace(string(res.Body)) return } @@ -148,11 +137,14 @@ func (provider *ECSRAMRoleCredentialsProvider) getCredentials() (session *sessio } } - requestUrl := securityCredURL + roleName - httpRequest, err := hookNewRequest(http.NewRequest)("GET", requestUrl, strings.NewReader("")) - if err != nil { - err = fmt.Errorf("refresh Ecs sts token err: %s", err.Error()) - return + req := &httputil.Request{ + Method: "GET", + Protocol: "http", + Host: "100.100.100.200", + Path: "/latest/meta-data/ram/security-credentials/" + roleName, + ConnectTimeout: 5 * time.Second, + ReadTimeout: 5 * time.Second, + Headers: map[string]string{}, } if provider.enableIMDSv2 { @@ -160,32 +152,22 @@ func (provider *ECSRAMRoleCredentialsProvider) getCredentials() (session *sessio if err != nil { return nil, err } - httpRequest.Header.Set("x-aliyun-ecs-metadata-token", metadataToken) + req.Headers["x-aliyun-ecs-metadata-token"] = metadataToken } - httpClient := &http.Client{ - Timeout: 5 * time.Second, - } - httpResponse, err := hookDo(httpClient.Do)(httpRequest) + res, err := httpDo(req) if err != nil { err = fmt.Errorf("refresh Ecs sts token err: %s", err.Error()) return } - defer httpResponse.Body.Close() - - responseBody, err := ioutil.ReadAll(httpResponse.Body) - if err != nil { - return - } - - if httpResponse.StatusCode != http.StatusOK { - err = fmt.Errorf("refresh Ecs sts token err, httpStatus: %d, message = %s", httpResponse.StatusCode, string(responseBody)) + if res.StatusCode != 200 { + err = fmt.Errorf("refresh Ecs sts token err, httpStatus: %d, message = %s", res.StatusCode, string(res.Body)) return } var data ecsRAMRoleResponse - err = json.Unmarshal(responseBody, &data) + err = json.Unmarshal(res.Body, &data) if err != nil { err = fmt.Errorf("refresh Ecs sts token err, json.Unmarshal fail: %s", err.Error()) return @@ -239,15 +221,23 @@ func (provider *ECSRAMRoleCredentialsProvider) GetProviderName() string { } func (provider *ECSRAMRoleCredentialsProvider) getMetadataToken() (metadataToken string, err error) { - request := request.NewCommonRequest() - request.URL = apiTokenURL - request.Method = "PUT" - request.Headers["X-aliyun-ecs-metadata-token-ttl-seconds"] = strconv.Itoa(provider.metadataTokenDurationSeconds) - content, err := doAction(request, provider.runtime) + // PUT http://100.100.100.200/latest/api/token + req := &httputil.Request{ + Method: "PUT", + Protocol: "http", + Host: "100.100.100.200", + Path: "/latest/api/token", + Headers: map[string]string{ + "X-aliyun-ecs-metadata-token-ttl-seconds": strconv.Itoa(provider.metadataTokenDurationSeconds), + }, + ConnectTimeout: 5 * time.Second, + ReadTimeout: 5 * time.Second, + } + res, err := httpDo(req) if err != nil { err = fmt.Errorf("get metadata token failed: %s", err.Error()) return } - metadataToken = string(content) + metadataToken = string(res.Body) return } diff --git a/credentials/internal/providers/ecs_ram_role_test.go b/credentials/internal/providers/ecs_ram_role_test.go index 36f03bd..d4cf538 100644 --- a/credentials/internal/providers/ecs_ram_role_test.go +++ b/credentials/internal/providers/ecs_ram_role_test.go @@ -2,13 +2,10 @@ package providers import ( "errors" - "io" - "io/ioutil" - "net/http" - "strconv" "testing" "time" + httputil "github.com/aliyun/credentials-go/credentials/internal/http" "github.com/stretchr/testify/assert" ) @@ -30,75 +27,42 @@ func TestNewECSRAMRoleCredentialsProvider(t *testing.T) { } func TestECSRAMRoleCredentialsProvider_getRoleName(t *testing.T) { + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() + p, err := NewECSRAMRoleCredentialsProviderBuilder().Build() assert.Nil(t, err) - originNewRequest := hookNewRequest - defer func() { hookNewRequest = originNewRequest }() - - // case 1: mock new http request failed - hookNewRequest = func(fn newReuqest) newReuqest { - return func(method, url string, body io.Reader) (*http.Request, error) { - return nil, errors.New("new http request failed") - } + // case 1: server error + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + err = errors.New("mock server error") + return } - _, err = p.getRoleName() - assert.NotNil(t, err) - assert.Equal(t, "get role name failed: new http request failed", err.Error()) - // reset new request - hookNewRequest = originNewRequest - - originDo := hookDo - defer func() { hookDo = originDo }() - // case 2: server error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - err = errors.New("mock server error") - return - } - } _, err = p.getRoleName() assert.NotNil(t, err) assert.Equal(t, "get role name failed: mock server error", err.Error()) - // case 3: 4xx error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(400, "4xx error") - return + // case 2: 4xx error + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 400, + Body: []byte("4xx error"), } + return } _, err = p.getRoleName() assert.NotNil(t, err) - assert.Equal(t, "get role name failed: request http://100.100.100.200/latest/meta-data/ram/security-credentials/ 400", err.Error()) - - // case 4: mock read response error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - status := strconv.Itoa(200) - res = &http.Response{ - Proto: "HTTP/1.1", - ProtoMajor: 1, - Header: map[string][]string{}, - StatusCode: 200, - Status: status + " " + http.StatusText(200), - } - res.Body = ioutil.NopCloser(&errorReader{}) - return - } - } - _, err = p.getRoleName() - assert.NotNil(t, err) - assert.Equal(t, "read failed", err.Error()) + assert.Equal(t, "get role name failed: GET http://100.100.100.200/latest/meta-data/ram/security-credentials/ 400", err.Error()) - // case 5: value json - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, "rolename") - return + // case 3: ok + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("rolename"), } + return } roleName, err := p.getRoleName() assert.Nil(t, err) @@ -106,34 +70,38 @@ func TestECSRAMRoleCredentialsProvider_getRoleName(t *testing.T) { } func TestECSRAMRoleCredentialsProvider_getRoleNameWithMetadataV2(t *testing.T) { + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() + p, err := NewECSRAMRoleCredentialsProviderBuilder().WithEnableIMDSv2(true).Build() assert.Nil(t, err) // case 1: get metadata token failed - originDo := hookDo - defer func() { hookDo = originDo }() - - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - err = errors.New("mock server error") - return - } + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + err = errors.New("mock server error") + return } + _, err = p.getRoleName() assert.NotNil(t, err) assert.Equal(t, "get metadata token failed: mock server error", err.Error()) // case 2: return token - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - if req.URL.Path == "/latest/api/token" { - res = mockResponse(200, `tokenxxxxx`) - } else { - assert.Equal(t, "tokenxxxxx", req.Header.Get("x-aliyun-ecs-metadata-token")) - res = mockResponse(200, "rolename") + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/latest/api/token" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("tokenxxxxx"), + } + } else { + assert.Equal(t, "tokenxxxxx", req.Headers["x-aliyun-ecs-metadata-token"]) + + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("rolename"), } - return } + return } roleName, err := p.getRoleName() @@ -142,209 +110,156 @@ func TestECSRAMRoleCredentialsProvider_getRoleNameWithMetadataV2(t *testing.T) { } func TestECSRAMRoleCredentialsProvider_getCredentials(t *testing.T) { - originDo := hookDo - defer func() { hookDo = originDo }() + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() p, err := NewECSRAMRoleCredentialsProviderBuilder().Build() assert.Nil(t, err) // case 1: server error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - err = errors.New("mock server error") - return - } + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + err = errors.New("mock server error") + return } _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "get role name failed: mock server error", err.Error()) - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { - res = mockResponse(200, "rolename") - return + // case 2: get role name ok, get credentials failed with server error + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/latest/meta-data/ram/security-credentials/" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("rolename"), } - - err = errors.New("mock server error") return } - } - - originNewRequest := hookNewRequest - defer func() { hookNewRequest = originNewRequest }() - - // case 2: mock new http request failed - hookNewRequest = func(fn newReuqest) newReuqest { - return func(method, url string, body io.Reader) (*http.Request, error) { - if url == "http://100.100.100.200/latest/meta-data/ram/security-credentials/rolename" { - return nil, errors.New("new http request failed") - } - return http.NewRequest(method, url, body) - } + err = errors.New("mock server error") + return } - _, err = p.getCredentials() - assert.NotNil(t, err) - assert.Equal(t, "refresh Ecs sts token err: new http request failed", err.Error()) - - hookNewRequest = originNewRequest - - // case 3 - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { - res = mockResponse(200, "rolename") - return - } - - if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { - err = errors.New("mock server error") - return - } - return - } - } _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "refresh Ecs sts token err: mock server error", err.Error()) - // case 4: mock read response error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { - res = mockResponse(200, "rolename") - return - } - - if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { - status := strconv.Itoa(200) - res = &http.Response{ - Proto: "HTTP/1.1", - ProtoMajor: 1, - Header: map[string][]string{}, - StatusCode: 200, - Status: status + " " + http.StatusText(200), - } - res.Body = ioutil.NopCloser(&errorReader{}) - return + // case 3: 4xx error + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/latest/meta-data/ram/security-credentials/" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("rolename"), } return } - } - _, err = p.getCredentials() - assert.NotNil(t, err) - assert.Equal(t, "read failed", err.Error()) - - // case 4: 4xx error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { - res = mockResponse(200, "rolename") - return - } - if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { - res = mockResponse(400, "4xx error") - return - } - return + res = &httputil.Response{ + StatusCode: 400, + Body: []byte("4xx error"), } + return } + _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "refresh Ecs sts token err, httpStatus: 400, message = 4xx error", err.Error()) - // case 5: invalid json - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { - res = mockResponse(200, "rolename") - return - } - - if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { - res = mockResponse(200, "invalid json") - return + // case 4: invalid json + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/latest/meta-data/ram/security-credentials/" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("rolename"), } return } + + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("invalid json"), + } + return } + _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "refresh Ecs sts token err, json.Unmarshal fail: invalid character 'i' looking for beginning of value", err.Error()) - // case 6: empty response json - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { - res = mockResponse(200, "rolename") - return - } - - if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { - res = mockResponse(200, "null") - return + // case 5: empty response json + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/latest/meta-data/ram/security-credentials/" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("rolename"), } return } + + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("null"), + } + return } + _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "refresh Ecs sts token err, fail to get credentials", err.Error()) - // case 7: empty session ak response json - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { - res = mockResponse(200, "rolename") - return - } - - if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { - res = mockResponse(200, `{}`) - return + // case 6: empty session ak response json + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/latest/meta-data/ram/security-credentials/" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("rolename"), } return } + + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("{}"), + } + return } _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "refresh Ecs sts token err, fail to get credentials", err.Error()) - // case 8: non-success response - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { - res = mockResponse(200, "rolename") - return - } - - if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { - res = mockResponse(200, `{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Failed"}`) - return + // case 7: non-success response + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/latest/meta-data/ram/security-credentials/" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("rolename"), } return } + + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Failed"}`), + } + return } _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "refresh Ecs sts token err, Code is not Success", err.Error()) // case 8: mock ok value - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - if req.URL.Path == "/latest/meta-data/ram/security-credentials/" { - res = mockResponse(200, "rolename") - return - } - - if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { - res = mockResponse(200, `{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Success"}`) - return + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/latest/meta-data/ram/security-credentials/" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("rolename"), } return } + + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Success"}`), + } + return } creds, err := p.getCredentials() assert.Nil(t, err) @@ -363,36 +278,37 @@ func TestECSRAMRoleCredentialsProvider_getCredentials(t *testing.T) { } func TestECSRAMRoleCredentialsProvider_getCredentialsWithMetadataV2(t *testing.T) { - originDo := hookDo - defer func() { hookDo = originDo }() + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() p, err := NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("rolename").WithEnableIMDSv2(true).Build() assert.Nil(t, err) // case 1: get metadata token failed - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - err = errors.New("mock server error") - return - } + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + err = errors.New("mock server error") + return } + _, err = p.getCredentials() assert.NotNil(t, err) assert.Equal(t, "get metadata token failed: mock server error", err.Error()) // case 2: return token - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - if req.URL.Path == "/latest/api/token" { - res = mockResponse(200, `tokenxxxxx`) - return + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + if req.Path == "/latest/api/token" { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("tokenxxxxx"), } - if req.URL.Path == "/latest/meta-data/ram/security-credentials/rolename" { - assert.Equal(t, "tokenxxxxx", req.Header.Get("x-aliyun-ecs-metadata-token")) - res = mockResponse(200, `{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Success"}`) + } else if req.Path == "/latest/meta-data/ram/security-credentials/rolename" { + assert.Equal(t, "tokenxxxxx", req.Headers["x-aliyun-ecs-metadata-token"]) + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Success"}`), } - return } + return } creds, err := p.getCredentials() @@ -412,39 +328,39 @@ func TestECSRAMRoleCredentialsProvider_getCredentialsWithMetadataV2(t *testing.T } func TestECSRAMRoleCredentialsProviderGetCredentials(t *testing.T) { - originDo := hookDo - defer func() { hookDo = originDo }() + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() p, err := NewECSRAMRoleCredentialsProviderBuilder().WithRoleName("rolename").Build() assert.Nil(t, err) // case 1: get credentials failed - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - err = errors.New("mock server error") - return - } + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + err = errors.New("mock server error") + return } _, err = p.GetCredentials() assert.NotNil(t, err) assert.Equal(t, "refresh Ecs sts token err: mock server error", err.Error()) // case 2: get invalid expiration - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, `{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"invalidexpiration","SecurityToken":"token","Code":"Success"}`) - return + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"AccessKeyId":"saki","AccessKeySecret":"saks","Expiration":"invalidexpiration","SecurityToken":"token","Code":"Success"}`), } + return } _, err = p.GetCredentials() assert.NotNil(t, err) assert.Equal(t, "parsing time \"invalidexpiration\" as \"2006-01-02T15:04:05Z\": cannot parse \"invalidexpiration\" as \"2006\"", err.Error()) // case 3: happy result - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, `{"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Success"}`) - return + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte(`{"AccessKeyId":"akid","AccessKeySecret":"aksecret","Expiration":"2021-10-20T04:27:09Z","SecurityToken":"token","Code":"Success"}`), } + return } cc, err := p.GetCredentials() assert.Nil(t, err) @@ -455,28 +371,28 @@ func TestECSRAMRoleCredentialsProviderGetCredentials(t *testing.T) { } func TestECSRAMRoleCredentialsProvider_getMetadataToken(t *testing.T) { - originDo := hookDo - defer func() { hookDo = originDo }() + originHttpDo := httpDo + defer func() { httpDo = originHttpDo }() p, err := NewECSRAMRoleCredentialsProviderBuilder().Build() assert.Nil(t, err) // case 1: server error - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - err = errors.New("mock server error") - return - } + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + err = errors.New("mock server error") + return } + _, err = p.getMetadataToken() assert.NotNil(t, err) assert.Equal(t, "get metadata token failed: mock server error", err.Error()) // case 2: return token - hookDo = func(fn do) do { - return func(req *http.Request) (res *http.Response, err error) { - res = mockResponse(200, `tokenxxxxx`) - return + httpDo = func(req *httputil.Request) (res *httputil.Response, err error) { + res = &httputil.Response{ + StatusCode: 200, + Body: []byte("tokenxxxxx"), } + return } metadataToken, err := p.getMetadataToken() assert.Nil(t, err) diff --git a/credentials/internal/providers/hook.go b/credentials/internal/providers/hook.go index 8248138..f09e65b 100644 --- a/credentials/internal/providers/hook.go +++ b/credentials/internal/providers/hook.go @@ -3,8 +3,12 @@ package providers import ( "io" "net/http" + + httputil "github.com/aliyun/credentials-go/credentials/internal/http" ) +var httpDo = httputil.Do + type newReuqest func(method, url string, body io.Reader) (*http.Request, error) var hookNewRequest = func(fn newReuqest) newReuqest { @@ -16,7 +20,3 @@ type do func(req *http.Request) (*http.Response, error) var hookDo = func(fn do) do { return fn } - -var hookParse = func(err error) error { - return err -} diff --git a/credentials/internal/providers/http.go b/credentials/internal/providers/http.go index 643cee6..352ed9f 100644 --- a/credentials/internal/providers/http.go +++ b/credentials/internal/providers/http.go @@ -2,22 +2,11 @@ package providers import ( "bytes" - "fmt" "io/ioutil" "net/http" - "net/url" "strconv" - "strings" - "time" - - "github.com/alibabacloud-go/debug/debug" - "github.com/aliyun/credentials-go/credentials/internal/utils" - "github.com/aliyun/credentials-go/credentials/request" - "github.com/aliyun/credentials-go/credentials/response" ) -var debuglog = debug.Init("credential") - func mockResponse(statusCode int, content string) (res *http.Response) { status := strconv.Itoa(statusCode) res = &http.Response{ @@ -30,61 +19,3 @@ func mockResponse(statusCode int, content string) (res *http.Response) { res.Body = ioutil.NopCloser(bytes.NewReader([]byte(content))) return } - -func doAction(request *request.CommonRequest, runtime *utils.Runtime) (content []byte, err error) { - var urlEncoded string - if request.BodyParams != nil { - urlEncoded = utils.GetURLFormedMap(request.BodyParams) - } - httpRequest, err := http.NewRequest(request.Method, request.URL, strings.NewReader(urlEncoded)) - if err != nil { - return - } - httpRequest.Proto = "HTTP/1.1" - httpRequest.Host = request.Domain - debuglog("> %s %s %s", httpRequest.Method, httpRequest.URL.RequestURI(), httpRequest.Proto) - debuglog("> Host: %s", httpRequest.Host) - for key, value := range request.Headers { - if value != "" { - debuglog("> %s: %s", key, value) - httpRequest.Header[key] = []string{value} - } - } - debuglog(">") - httpClient := &http.Client{} - httpClient.Timeout = time.Duration(runtime.ReadTimeout) * time.Second - proxy := &url.URL{} - if runtime.Proxy != "" { - proxy, err = url.Parse(runtime.Proxy) - if err != nil { - return - } - } - trans := &http.Transport{} - if proxy != nil && runtime.Proxy != "" { - trans.Proxy = http.ProxyURL(proxy) - } - trans.DialContext = utils.Timeout(time.Duration(runtime.ConnectTimeout) * time.Second) - httpClient.Transport = trans - httpResponse, err := hookDo(httpClient.Do)(httpRequest) - if err != nil { - return - } - debuglog("< %s %s", httpResponse.Proto, httpResponse.Status) - for key, value := range httpResponse.Header { - debuglog("< %s: %v", key, strings.Join(value, "")) - } - debuglog("<") - - resp := &response.CommonResponse{} - err = hookParse(resp.ParseFromHTTPResponse(httpResponse)) - if err != nil { - return - } - debuglog("%s", resp.GetHTTPContentString()) - if resp.GetHTTPStatus() != http.StatusOK { - err = fmt.Errorf("httpStatus: %d, message = %s", resp.GetHTTPStatus(), resp.GetHTTPContentString()) - return - } - return resp.GetHTTPContentBytes(), nil -}