Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refine ecs ram role credentials provider #111

Merged
merged 1 commit into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions credentials/internal/http/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package http

import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"

"github.com/alibabacloud-go/debug/debug"
"github.com/aliyun/credentials-go/credentials/internal/utils"
)

type Request struct {
Method string // http request method
Protocol string // http or https
Host string // http host
ReadTimeout time.Duration
ConnectTimeout time.Duration
Proxy string // http proxy
Form map[string]string // http form
Body []byte // request body for JSON or stream
Path string
Queries map[string]string
Headers map[string]string
}

func (req *Request) BuildRequestURL() string {
httpUrl := fmt.Sprintf("%s://%s%s", req.Protocol, req.Host, req.Path)

querystring := utils.GetURLFormedMap(req.Queries)
if querystring != "" {
httpUrl = httpUrl + "?" + querystring
}

return fmt.Sprintf("%s %s", req.Method, httpUrl)
}

type Response struct {
StatusCode int
Headers map[string]string
Body []byte
}

var newRequest = http.NewRequest

type do func(req *http.Request) (*http.Response, error)

var hookDo = func(fn do) do {
return fn
}

var debuglog = debug.Init("credential")

func Do(req *Request) (res *Response, err error) {
querystring := utils.GetURLFormedMap(req.Queries)
// do request
httpUrl := fmt.Sprintf("%s://%s%s?%s", req.Protocol, req.Host, req.Path, querystring)

var body io.Reader
if req.Method == "GET" {
body = strings.NewReader("")
} else {
body = strings.NewReader(utils.GetURLFormedMap(req.Form))
}

httpRequest, err := newRequest(req.Method, httpUrl, body)
if err != nil {
return
}

if req.Form != nil {
httpRequest.Header["Content-Type"] = []string{"application/x-www-form-urlencoded"}
}

for key, value := range req.Headers {
if value != "" {
debuglog("> %s: %s", key, value)
httpRequest.Header.Set(key, value)
}
}

httpClient := &http.Client{}

if req.ReadTimeout != 0 {
httpClient.Timeout = req.ReadTimeout
}

transport := &http.Transport{}
if req.Proxy != "" {
var proxy *url.URL
proxy, err = url.Parse(req.Proxy)
if err != nil {
return
}
transport.Proxy = http.ProxyURL(proxy)
}

if req.ConnectTimeout != 0 {
transport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{
Timeout: req.ConnectTimeout,
DualStack: true,
}).DialContext(ctx, network, address)
}
}

httpClient.Transport = transport

httpResponse, err := hookDo(httpClient.Do)(httpRequest)
if err != nil {
return
}

defer httpResponse.Body.Close()

responseBody, err := ioutil.ReadAll(httpResponse.Body)
if err != nil {
return
}
res = &Response{
StatusCode: httpResponse.StatusCode,
Headers: make(map[string]string),
Body: responseBody,
}
for key, v := range httpResponse.Header {
res.Headers[key] = v[0]
}

return
}
189 changes: 189 additions & 0 deletions credentials/internal/http/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package http

import (
"errors"
"io"
"io/ioutil"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestRequest(t *testing.T) {
req := &Request{
Method: "GET",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
}
assert.Equal(t, "GET http://www.aliyun.com/", req.BuildRequestURL())
// With query
req = &Request{
Method: "GET",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
Queries: map[string]string{
"spm": "test",
},
}
assert.Equal(t, "GET http://www.aliyun.com/?spm=test", req.BuildRequestURL())
}

func TestDoGet(t *testing.T) {
req := &Request{
Method: "GET",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
}
res, err := Do(req)
assert.Nil(t, err)
assert.NotNil(t, res)
assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, "text/html; charset=utf-8", res.Headers["Content-Type"])
}

func TestDoPost(t *testing.T) {
req := &Request{
Method: "POST",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
Form: map[string]string{
"URL": "HI",
},
Headers: map[string]string{
"Accept-Language": "zh",
},
}
res, err := Do(req)
assert.Nil(t, err)
assert.NotNil(t, res)
assert.Equal(t, 200, res.StatusCode)
assert.Equal(t, "text/html; charset=utf-8", res.Headers["Content-Type"])
}

type errorReader struct {
}

func (r *errorReader) Read(p []byte) (n int, err error) {
err = errors.New("read failed")
return
}

func TestDoWithError(t *testing.T) {
originNewRequest := newRequest
defer func() { newRequest = originNewRequest }()

// case 1: mock new http request failed
newRequest = func(method, url string, body io.Reader) (*http.Request, error) {
return nil, errors.New("new http request failed")
}

req := &Request{
Method: "POST",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
Form: map[string]string{
"URL": "HI",
},
Headers: map[string]string{
"Accept-Language": "zh",
},
}
_, err := Do(req)
assert.EqualError(t, err, "new http request failed")

// reset new request
newRequest = originNewRequest

// case 2: server error
originDo := hookDo
defer func() { hookDo = originDo }()
hookDo = func(fn do) do {
return func(req *http.Request) (res *http.Response, err error) {
err = errors.New("mock server error")
return
}
}
_, err = Do(req)
assert.EqualError(t, err, "mock server error")

// case 4: mock read response error
hookDo = func(fn do) do {
return func(req *http.Request) (res *http.Response, err error) {
res = &http.Response{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: map[string][]string{},
StatusCode: 200,
Status: "200 " + http.StatusText(200),
}
res.Body = ioutil.NopCloser(&errorReader{})
return
}
}

_, err = Do(req)
assert.EqualError(t, err, "read failed")
}

func TestDoWithProxy(t *testing.T) {
req := &Request{
Method: "POST",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
Form: map[string]string{
"URL": "HI",
},
Headers: map[string]string{
"Accept-Language": "zh",
},
Proxy: "http://localhost:9999/",
}
_, err := Do(req)
assert.Contains(t, err.Error(), "proxyconnect tcp: dial tcp")
assert.Contains(t, err.Error(), "connect: connection refused")

// invalid proxy url
req.Proxy = string([]byte{0x7f})
_, err = Do(req)
assert.Contains(t, err.Error(), "net/url: invalid control character in URL")
}

func TestDoWithConnectTimeout(t *testing.T) {
req := &Request{
Method: "POST",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
Form: map[string]string{
"URL": "HI",
},
Headers: map[string]string{
"Accept-Language": "zh",
},
ConnectTimeout: 1 * time.Nanosecond,
}
_, err := Do(req)
assert.Contains(t, err.Error(), "dial tcp: ")
assert.Contains(t, err.Error(), "i/o timeout")
}

func TestDoWithReadTimeout(t *testing.T) {
req := &Request{
Method: "POST",
Protocol: "http",
Host: "www.aliyun.com",
Path: "/",
ReadTimeout: 1 * time.Nanosecond,
}
_, err := Do(req)
assert.Contains(t, err.Error(), "(Client.Timeout exceeded while awaiting headers)")
}
Loading