Skip to content

Commit

Permalink
refine ecs ram role credentials provider
Browse files Browse the repository at this point in the history
  • Loading branch information
JacksonTian authored and yndu13 committed Aug 23, 2024
1 parent b5116a2 commit 3bd38fa
Show file tree
Hide file tree
Showing 6 changed files with 538 additions and 377 deletions.
135 changes: 135 additions & 0 deletions credentials/internal/http/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package http

import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"

"github.com/alibabacloud-go/debug/debug"
"github.com/aliyun/credentials-go/credentials/internal/utils"
)

type Request struct {
Method string // http request method
Protocol string // http or https
Host string // http host
ReadTimeout time.Duration
ConnectTimeout time.Duration
Proxy string // http proxy
Form map[string]string // http form
Body []byte // request body for JSON or stream
Path string
Queries map[string]string
Headers map[string]string
}

func (req *Request) BuildRequestURL() string {
httpUrl := fmt.Sprintf("%s://%s%s", req.Protocol, req.Host, req.Path)

querystring := utils.GetURLFormedMap(req.Queries)
if querystring != "" {
httpUrl = httpUrl + "?" + querystring
}

return fmt.Sprintf("%s %s", req.Method, httpUrl)
}

type Response struct {
StatusCode int
Headers map[string]string
Body []byte
}

var newRequest = http.NewRequest

type do func(req *http.Request) (*http.Response, error)

var hookDo = func(fn do) do {
return fn
}

var debuglog = debug.Init("credential")

func Do(req *Request) (res *Response, err error) {
querystring := utils.GetURLFormedMap(req.Queries)
// do request
httpUrl := fmt.Sprintf("%s://%s%s?%s", req.Protocol, req.Host, req.Path, querystring)

var body io.Reader
if req.Method == "GET" {
body = strings.NewReader("")
} else {
body = strings.NewReader(utils.GetURLFormedMap(req.Form))
}

httpRequest, err := newRequest(req.Method, httpUrl, body)
if err != nil {
return
}

if req.Form != nil {
httpRequest.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"}
}

for key, value := range req.Headers {
if value != "" {
debuglog("> %s: %s", key, value)
httpRequest.Header.Set(key, value)
}
}

httpClient := &http.Client{}

if req.ReadTimeout != 0 {
httpClient.Timeout = req.ReadTimeout
}

transport := &http.Transport{}
if req.Proxy != "" {
var proxy *url.URL
proxy, err = url.Parse(req.Proxy)
if err != nil {
return
}
transport.Proxy = http.ProxyURL(proxy)
}

if req.ConnectTimeout != 0 {
transport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{
Timeout: req.ConnectTimeout,
DualStack: true,
}).DialContext(ctx, network, address)
}
}

httpClient.Transport = transport

httpResponse, err := hookDo(httpClient.Do)(httpRequest)
if err != nil {
return
}

defer httpResponse.Body.Close()

responseBody, err := ioutil.ReadAll(httpResponse.Body)
if err != nil {
return
}
res = &Response{
StatusCode: httpResponse.StatusCode,
Headers: make(map[string]string),
Body: responseBody,
}
for key, v := range httpResponse.Header {
res.Headers[key] = v[0]
}

return
}
189 changes: 189 additions & 0 deletions credentials/internal/http/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package http

import (
"errors"
"io"
"io/ioutil"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestRequest(t *testing.T) {
req := &Request{
Method: "GET",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
}
assert.Equal(t, "GET http://www.aliyun.com/", req.BuildRequestURL())
// With query
req = &Request{
Method: "GET",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
Queries: map[string]string{
"spm": "test",
},
}
assert.Equal(t, "GET http://www.aliyun.com/?spm=test", req.BuildRequestURL())
}

func TestDoGet(t *testing.T) {
req := &Request{
Method: "GET",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
}
res, err := Do(req)
assert.Nil(t, err)
assert.NotNil(t, res)
assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, "text/html; charset=utf-8", res.Headers["Content-Type"])
}

func TestDoPost(t *testing.T) {
req := &Request{
Method: "POST",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
Form: map[string]string{
"URL": "HI",
},
Headers: map[string]string{
"Accept-Language": "zh",
},
}
res, err := Do(req)
assert.Nil(t, err)
assert.NotNil(t, res)
assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, "text/html; charset=utf-8", res.Headers["Content-Type"])
}

type errorReader struct {
}

func (r *errorReader) Read(p []byte) (n int, err error) {
err = errors.New("read failed")
return
}

func TestDoWithError(t *testing.T) {
originNewRequest := newRequest
defer func() { newRequest = originNewRequest }()

// case 1: mock new http request failed
newRequest = func(method, url string, body io.Reader) (*http.Request, error) {
return nil, errors.New("new http request failed")
}

req := &Request{
Method: "POST",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
Form: map[string]string{
"URL": "HI",
},
Headers: map[string]string{
"Accept-Language": "zh",
},
}
_, err := Do(req)
assert.EqualError(t, err, "new http request failed")

// reset new request
newRequest = originNewRequest

// case 2: server error
originDo := hookDo
defer func() { hookDo = originDo }()
hookDo = func(fn do) do {
return func(req *http.Request) (res *http.Response, err error) {
err = errors.New("mock server error")
return
}
}
_, err = Do(req)
assert.EqualError(t, err, "mock server error")

// case 4: mock read response error
hookDo = func(fn do) do {
return func(req *http.Request) (res *http.Response, err error) {
res = &http.Response{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: map[string][]string{},
StatusCode: 200,
Status: "200 " + http.StatusText(200),
}
res.Body = ioutil.NopCloser(&errorReader{})
return
}
}

_, err = Do(req)
assert.EqualError(t, err, "read failed")
}

func TestDoWithProxy(t *testing.T) {
req := &Request{
Method: "POST",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
Form: map[string]string{
"URL": "HI",
},
Headers: map[string]string{
"Accept-Language": "zh",
},
Proxy: "http://localhost:9999/",
}
_, err := Do(req)
assert.Contains(t, err.Error(), "proxyconnect tcp: dial tcp")
assert.Contains(t, err.Error(), "connect: connection refused")

// invalid proxy url
req.Proxy = string([]byte{0x7f})
_, err = Do(req)
assert.Contains(t, err.Error(), "net/url: invalid control character in URL")
}

func TestDoWithConnectTimeout(t *testing.T) {
req := &Request{
Method: "POST",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
Form: map[string]string{
"URL": "HI",
},
Headers: map[string]string{
"Accept-Language": "zh",
},
ConnectTimeout: 1 * time.Nanosecond,
}
_, err := Do(req)
assert.Contains(t, err.Error(), "dial tcp: ")
assert.Contains(t, err.Error(), "i/o timeout")
}

func TestDoWithReadTimeout(t *testing.T) {
req := &Request{
Method: "POST",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
ReadTimeout: 1 * time.Nanosecond,
}
_, err := Do(req)
assert.Contains(t, err.Error(), "(Client.Timeout exceeded while awaiting headers)")
}
Loading

0 comments on commit 3bd38fa

Please sign in to comment.