Skip to content

Commit

Permalink
s3: move to aws-sdk-go-v2 package
Browse files Browse the repository at this point in the history
Signed-off-by: Janusz Marcinkiewicz <[email protected]>
  • Loading branch information
VirrageS committed Feb 9, 2024
1 parent 3e71357 commit b094fe8
Show file tree
Hide file tree
Showing 13 changed files with 201 additions and 148 deletions.
155 changes: 77 additions & 78 deletions ais/backend/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ import (
"github.com/NVIDIA/aistore/cmn/nlog"
"github.com/NVIDIA/aistore/core"
"github.com/NVIDIA/aistore/core/meta"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/aws/aws-sdk-go-v2/aws"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
"github.com/aws/aws-sdk-go-v2/config"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go"
)

const (
Expand All @@ -54,7 +55,7 @@ type (
)

var (
clients map[string]*s3.S3 // one s3.Client aka "svc" per (profile, region, endpoint) triplet
clients map[string]*s3.Client // one s3.Client aka "svc" per (profile, region, endpoint) triplet
cmu sync.Mutex
s3Endpoint string
awsProfile string
Expand All @@ -64,7 +65,7 @@ var (
var _ core.BackendProvider = (*awsProvider)(nil)

func NewAWS(t core.TargetPut) (core.BackendProvider, error) {
clients = make(map[string]*s3.S3, 2)
clients = make(map[string]*s3.Client, 2)
s3Endpoint = os.Getenv(awsEnvS3Endpoint)
awsProfile = os.Getenv(awsEnvConfigProfile)
return &awsProvider{t: t}, nil
Expand All @@ -91,7 +92,7 @@ func (*awsProvider) CreateBucket(_ *meta.Bck) (int, error) {

func (*awsProvider) HeadBucket(_ ctx, bck *meta.Bck) (bckProps cos.StrKVs, errCode int, err error) {
var (
svc *s3.S3
svc *s3.Client
region string
errC error
cloudBck = bck.RemoteBck()
Expand All @@ -113,7 +114,7 @@ func (*awsProvider) HeadBucket(_ ctx, bck *meta.Bck) (bckProps cos.StrKVs, errCo
return
}
}
region = *svc.Config.Region
region = svc.Options().Region
debug.Assert(region != "")

// NOTE: return a few assorted fields, specifically to fill-in vendor-specific `cmn.ExtraProps`
Expand Down Expand Up @@ -142,7 +143,7 @@ const versionedPageSize = 20

func (awsp *awsProvider) ListObjects(bck *meta.Bck, msg *apc.LsoMsg, lst *cmn.LsoResult) (errCode int, err error) {
var (
svc *s3.S3
svc *s3.Client
h = cmn.BackendHelpers.Amazon
cloudBck = bck.RemoteBck()
versioning bool
Expand All @@ -165,9 +166,9 @@ func (awsp *awsProvider) ListObjects(bck *meta.Bck, msg *apc.LsoMsg, lst *cmn.Ls
if versioning {
msg.PageSize = min(versionedPageSize, msg.PageSize)
}
params.MaxKeys = aws.Int64(int64(msg.PageSize))
params.MaxKeys = aws.Int32(int32(msg.PageSize))

resp, err := svc.ListObjectsV2(params)
resp, err := svc.ListObjectsV2(context.Background(), params)
if err != nil {
if cmn.Rom.FastV(4, cos.SmoduleBackend) {
nlog.Infoln("list_objects", cloudBck.Name, err)
Expand Down Expand Up @@ -222,7 +223,7 @@ func (awsp *awsProvider) ListObjects(bck *meta.Bck, msg *apc.LsoMsg, lst *cmn.Ls
)
for _, entry := range lst.Entries {
verParams.Prefix = aws.String(entry.Name)
verResp, err := svc.ListObjectVersions(verParams)
verResp, err := svc.ListObjectVersions(context.Background(), verParams)
if err != nil {
return awsErrorToAISError(err, cloudBck, "")
}
Expand Down Expand Up @@ -254,7 +255,7 @@ func (*awsProvider) ListBuckets(cmn.QueryBcks) (bcks cmn.Bcks, errCode int, err
errCode, err = awsErrorToAISError(err, &cmn.Bck{Provider: apc.AWS}, "")
return
}
result, err := svc.ListBuckets(&s3.ListBucketsInput{})
result, err := svc.ListBuckets(context.Background(), &s3.ListBucketsInput{})
if err != nil {
errCode, err = awsErrorToAISError(err, &cmn.Bck{Provider: apc.AWS}, "")
return
Expand All @@ -263,10 +264,10 @@ func (*awsProvider) ListBuckets(cmn.QueryBcks) (bcks cmn.Bcks, errCode int, err
bcks = make(cmn.Bcks, len(result.Buckets))
for idx, bck := range result.Buckets {
if cmn.Rom.FastV(4, cos.SmoduleBackend) {
nlog.Infoln("[bucket_names]", aws.StringValue(bck.Name), "created", *bck.CreationDate)
nlog.Infoln("[bucket_names]", aws.ToString(bck.Name), "created", *bck.CreationDate)
}
bcks[idx] = cmn.Bck{
Name: aws.StringValue(bck.Name),
Name: aws.ToString(bck.Name),
Provider: apc.AWS,
}
}
Expand All @@ -279,16 +280,16 @@ func (*awsProvider) ListBuckets(cmn.QueryBcks) (bcks cmn.Bcks, errCode int, err

func (*awsProvider) HeadObj(_ ctx, lom *core.LOM) (oa *cmn.ObjAttrs, errCode int, err error) {
var (
svc *s3.Client
headOutput *s3.HeadObjectOutput
svc *s3.S3
h = cmn.BackendHelpers.Amazon
cloudBck = lom.Bck().RemoteBck()
)
svc, _, err = newClient(sessConf{bck: cloudBck}, "[head_object]")
if err != nil && cmn.Rom.FastV(4, cos.SmoduleBackend) {
nlog.Warningln(err)
}
headOutput, err = svc.HeadObject(&s3.HeadObjectInput{
headOutput, err = svc.HeadObject(context.Background(), &s3.HeadObjectInput{
Bucket: aws.String(cloudBck.Name),
Key: aws.String(lom.ObjName),
})
Expand Down Expand Up @@ -321,7 +322,7 @@ func (*awsProvider) HeadObj(_ ctx, lom *core.LOM) (oa *cmn.ObjAttrs, errCode int
md := headOutput.Metadata
if cksumType, ok := md[cos.S3MetadataChecksumType]; ok {
if cksumValue, ok := md[cos.S3MetadataChecksumVal]; ok {
oa.SetCksum(*cksumType, *cksumValue)
oa.SetCksum(cksumType, cksumValue)
}
}

Expand Down Expand Up @@ -377,13 +378,13 @@ func (*awsProvider) GetObjReader(ctx context.Context, lom *core.LOM, offset, len
if length > 0 {
rng := cmn.MakeRangeHdr(offset, length)
input.Range = aws.String(rng)
obj, err = svc.GetObjectWithContext(ctx, &input)
obj, err = svc.GetObject(ctx, &input)
if err != nil {
res.ErrCode, res.Err = awsErrorToAISError(err, cloudBck, lom.ObjName)
return
}
} else {
obj, err = svc.GetObjectWithContext(ctx, &input)
obj, err = svc.GetObject(ctx, &input)
if err != nil {
res.ErrCode, res.Err = awsErrorToAISError(err, cloudBck, lom.ObjName)
return
Expand All @@ -396,7 +397,7 @@ func (*awsProvider) GetObjReader(ctx context.Context, lom *core.LOM, offset, len
md := obj.Metadata
if cksumType, ok := md[cos.S3MetadataChecksumType]; ok {
if cksumValue, ok := md[cos.S3MetadataChecksumVal]; ok {
cksum := cos.NewCksum(*cksumType, *cksumValue)
cksum := cos.NewCksum(cksumType, cksumValue)
lom.SetCksum(cksum)
res.ExpCksum = cksum // precedence over md5 (<= ETag)
}
Expand Down Expand Up @@ -433,24 +434,24 @@ func _getCustom(lom *core.LOM, obj *s3.GetObjectOutput) (md5 *cos.Cksum) {

func (*awsProvider) PutObj(r io.ReadCloser, lom *core.LOM) (errCode int, err error) {
var (
svc *s3.S3
svc *s3.Client
uploadOutput *s3manager.UploadOutput
h = cmn.BackendHelpers.Amazon
cksumType, cksumValue = lom.Checksum().Get()
cloudBck = lom.Bck().RemoteBck()
md = make(map[string]*string, 2)
md = make(map[string]string, 2)
)

svc, _, err = newClient(sessConf{bck: cloudBck}, "[put_object]")
if err != nil && cmn.Rom.FastV(5, cos.SmoduleBackend) {
nlog.Warningln(err)
}

md[cos.S3MetadataChecksumType] = aws.String(cksumType)
md[cos.S3MetadataChecksumVal] = aws.String(cksumValue)
md[cos.S3MetadataChecksumType] = cksumType
md[cos.S3MetadataChecksumVal] = cksumValue

uploader := s3manager.NewUploaderWithClient(svc)
uploadOutput, err = uploader.Upload(&s3manager.UploadInput{
uploader := s3manager.NewUploader(svc)
uploadOutput, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String(cloudBck.Name),
Key: aws.String(lom.ObjName),
Body: r,
Expand Down Expand Up @@ -486,14 +487,14 @@ func (*awsProvider) PutObj(r io.ReadCloser, lom *core.LOM) (errCode int, err err

func (*awsProvider) DeleteObj(lom *core.LOM) (errCode int, err error) {
var (
svc *s3.S3
svc *s3.Client
cloudBck = lom.Bck().RemoteBck()
)
svc, _, err = newClient(sessConf{bck: cloudBck}, "[delete_object]")
if err != nil && cmn.Rom.FastV(4, cos.SmoduleBackend) {
nlog.Warningln(err)
}
_, err = svc.DeleteObject(&s3.DeleteObjectInput{
_, err = svc.DeleteObject(context.Background(), &s3.DeleteObjectInput{
Bucket: aws.String(cloudBck.Name),
Key: aws.String(lom.ObjName),
})
Expand All @@ -518,7 +519,7 @@ func (*awsProvider) DeleteObj(lom *core.LOM) (errCode int, err error) {
// From S3 SDK:
// "S3 methods are safe to use concurrently. It is not safe to modify mutate
// any of the struct's properties though."
func newClient(conf sessConf, tag string) (svc *s3.S3, region string, err error) {
func newClient(conf sessConf, tag string) (svc *s3.Client, region string, err error) {
var (
endpoint = s3Endpoint
profile = awsProfile
Expand Down Expand Up @@ -547,18 +548,21 @@ func newClient(conf sessConf, tag string) (svc *s3.S3, region string, err error)
}

// slow path
sess, config := _session(endpoint, profile)
cfg, err := loadConfig(endpoint, profile)
if err != nil {
return nil, "", err
}
if region == "" {
if tag != "" {
err = fmt.Errorf("%s: unknown region for bucket %s -- proceeding with default", tag, conf.bck)
}
svc = s3.New(sess)
svc = s3.NewFromConfig(cfg)
return
}
// have region
config.Region = aws.String(region)
svc = s3.New(sess, config)
debug.Assertf(region == *svc.Config.Region, "%s != %s", region, *svc.Config.Region)
svc = s3.NewFromConfig(cfg, func(options *s3.Options) {
options.Region = region
})
debug.Assertf(region == svc.Options().Region, "%s != %s", region, svc.Options().Region)

cmu.Lock()
clients[cid] = svc
Expand All @@ -582,71 +586,66 @@ func _cid(profile, region, endpoint string) string {
return sb.String()
}

// Create session using default creds from ~/.aws/credentials and environment variables.
func _session(endpoint, profile string) (*session.Session, *aws.Config) {
config := aws.Config{
HTTPClient: cmn.NewClient(cmn.TransportArgs{}),
LowerCaseHeaderMaps: apc.Bool(true),
// loadConfig create config using default creds from ~/.aws/credentials and environment variables.
func loadConfig(endpoint, profile string) (aws.Config, error) {
// NOTE: The AWS SDK for Go v2, uses lower case header maps by default.
cfg, err := config.LoadDefaultConfig(
context.Background(),
config.WithHTTPClient(cmn.NewClient(cmn.TransportArgs{})),
config.WithSharedConfigProfile(profile),
)
if err != nil {
return cfg, err
}
// `endpoint` is normally empty but could also be `Props.Extra.AWS.Endpoint` or `os.Getenv(awsEnvS3Endpoint)`
// (with bucket-specific `Props` taking precedence)
config.WithEndpoint(endpoint)

opts := session.Options{
SharedConfigState: session.SharedConfigEnable,
Config: config,
Profile: profile,
if endpoint != "" {
cfg.BaseEndpoint = aws.String(endpoint)
}
return session.Must(session.NewSessionWithOptions(opts)), &config
return cfg, nil
}

func getBucketVersioning(svc *s3.S3, bck *cmn.Bck) (enabled bool, errV error) {
func getBucketVersioning(svc *s3.Client, bck *cmn.Bck) (enabled bool, errV error) {
input := &s3.GetBucketVersioningInput{Bucket: aws.String(bck.Name)}
result, err := svc.GetBucketVersioning(input)
result, err := svc.GetBucketVersioning(context.Background(), input)
if err != nil {
return false, err
}
enabled = result.Status != nil && *result.Status == s3.BucketVersioningStatusEnabled
enabled = result.Status == types.BucketVersioningStatusEnabled
return
}

func getBucketLocation(svc *s3.S3, bckName string) (region string, err error) {
resp, err := svc.GetBucketLocation(&s3.GetBucketLocationInput{
func getBucketLocation(svc *s3.Client, bckName string) (region string, err error) {
resp, err := svc.GetBucketLocation(context.Background(), &s3.GetBucketLocationInput{
Bucket: aws.String(bckName),
})
if err != nil {
return
}
region = aws.StringValue(resp.LocationConstraint)

// NOTE: AWS API returns empty region "only" for 'us-east-1`
region = string(resp.LocationConstraint)
if region == "" {
region = endpoints.UsEast1RegionID
region = "us-east-1" // Buckets in region `us-east-1` have a LocationConstraint of null.
}
return
}

// For reference see https://github.com/aws/aws-sdk-go-v2/issues/1110#issuecomment-1054643716.
func awsErrorToAISError(awsError error, bck *cmn.Bck, objName string) (int, error) {
reqErr, ok := awsError.(awserr.RequestFailure)
if !ok {
var reqErr smithy.APIError
if !errors.As(awsError, &reqErr) {
return http.StatusInternalServerError, _awsErr(awsError)
}
awsCode, status := reqErr.Code(), reqErr.StatusCode()
switch {
case awsCode == s3.ErrCodeNoSuchBucket:
return status, cmn.NewErrRemoteBckNotFound(bck)
case awsCode == s3.ErrCodeNoSuchKey || (status == http.StatusNotFound && objName != ""):
debug.Assert(status == http.StatusNotFound, status) // expected
return status, errors.New("aws-error[NotFound: " + bck.Cname(objName) + "]")
case status == http.StatusForbidden && strings.Contains(awsCode, "AllAccessDisabled"):
// HACK: "not found or misspelled" vs "service not paid for" (the latter less likely)
err := _awsErr(awsError)
if cmn.Rom.FastV(4, cos.SmoduleBackend) {
nlog.Infoln(err)
}
return http.StatusNotFound, err

switch reqErr.(type) {
case *types.NoSuchBucket:
return http.StatusNotFound, cmn.NewErrRemoteBckNotFound(bck)
case *types.NoSuchKey:
return http.StatusNotFound, errors.New("aws-error[NotFound: " + bck.Cname(objName) + "]")
default:
return status, _awsErr(awsError)
var httpResponseErr *awshttp.ResponseError
if errors.As(awsError, &httpResponseErr) {
return httpResponseErr.HTTPStatusCode(), _awsErr(awsError)
}

return http.StatusBadRequest, _awsErr(awsError)
}
}

Expand Down
Loading

0 comments on commit b094fe8

Please sign in to comment.