Skip to content

Commit

Permalink
cloud: add new Uploader interface and implement for AWS
Browse files Browse the repository at this point in the history
This commit adds a new `cloud.Uploader` interface that combines
the upload and register into a single operation. The rational
is that with that we avoid leaking resource if e.g. the upload
works but the registration fails.
  • Loading branch information
mvo5 committed Jan 31, 2025
1 parent 689348b commit 2aeafec
Show file tree
Hide file tree
Showing 6 changed files with 383 additions and 5 deletions.
16 changes: 11 additions & 5 deletions pkg/cloud/awscloud/awscloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ func WaitUntilImportSnapshotTaskCompletedWithContext(c *ec2.EC2, ctx aws.Context
// The caller can also specify the name of the role used to do the import.
// If nil is given, the default one from the SDK is used (vmimport).
// Returns the image ID and the snapshot ID.
//
// XXX: make this return (string, string, error) instead of pointers
func (a *AWS) Register(name, bucket, key string, shareWith []string, rpmArch string, bootMode, importRole *string) (*string, *string, error) {
rpmArchToEC2Arch := map[string]string{
"x86_64": "x86_64",
Expand Down Expand Up @@ -294,11 +296,7 @@ func (a *AWS) Register(name, bucket, key string, shareWith []string, rpmArch str

// we no longer need the object in s3, let's just delete it
logrus.Infof("[AWS] 🧹 Deleting image from S3: %s/%s", bucket, key)
_, err = a.s3.DeleteObject(&s3.DeleteObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
})
if err != nil {
if err = a.DeleteObject(bucket, key); err != nil {
return nil, nil, err
}

Expand Down Expand Up @@ -388,6 +386,14 @@ func (a *AWS) Register(name, bucket, key string, shareWith []string, rpmArch str
return registerOutput.ImageId, snapshotID, nil
}

func (a *AWS) DeleteObject(bucket, key string) error {
_, err := a.s3.DeleteObject(&s3.DeleteObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
})
return err
}

// target region is determined by the region configured in the aws session
func (a *AWS) CopyImage(name, ami, sourceRegion string) (string, error) {
result, err := a.ec2.CopyImage(
Expand Down
11 changes: 11 additions & 0 deletions pkg/cloud/awscloud/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package awscloud

type AwsClient = awsClient

func MockNewAwsClient(f func(string) (awsClient, error)) (restore func()) {
saved := newAwsClient
newAwsClient = f
return func() {
newAwsClient = saved
}
}
134 changes: 134 additions & 0 deletions pkg/cloud/awscloud/uploader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package awscloud

import (
"errors"
"fmt"
"io"
"slices"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/google/uuid"

"github.com/osbuild/images/pkg/arch"
"github.com/osbuild/images/pkg/cloud"
)

type awsUploader struct {
client awsClient

region string
bucketName string
imageName string
targetArch string
}

type UploaderOptions struct {
TargetArch string
}

// testing support
type awsClient interface {
Regions() ([]string, error)
Buckets() ([]string, error)
CheckBucketPermission(string, S3Permission) (bool, error)
UploadFromReader(io.Reader, string, string) (*s3manager.UploadOutput, error)
Register(name, bucket, key string, shareWith []string, rpmArch string, bootMode, importRole *string) (*string, *string, error)
DeleteObject(string, string) error
}

var newAwsClient = func(region string) (awsClient, error) {
return NewDefault(region)
}

func NewUploader(region, bucketName, imageName string, opts *UploaderOptions) (cloud.Uploader, error) {
if opts == nil {
opts = &UploaderOptions{}
}
client, err := newAwsClient(region)
if err != nil {
return nil, err
}

return &awsUploader{
client: client,
region: region,
bucketName: bucketName,
imageName: imageName,
targetArch: opts.TargetArch,
}, nil
}

var _ cloud.Uploader = &awsUploader{}

func (au *awsUploader) Check(status io.Writer) error {
fmt.Fprintf(status, "Checking AWS region access...\n")
regions, err := au.client.Regions()
if err != nil {
return fmt.Errorf("retrieving AWS regions for '%s' failed: %w", au.region, err)
}

if !slices.Contains(regions, au.region) {
return fmt.Errorf("given AWS region '%s' not found", au.region)
}

fmt.Fprintf(status, "Checking AWS bucket...\n")
buckets, err := au.client.Buckets()
if err != nil {
return fmt.Errorf("retrieving AWS list of buckets failed: %w", err)
}
if !slices.Contains(buckets, au.bucketName) {
return fmt.Errorf("bucket '%s' not found in the given AWS account", au.bucketName)
}

fmt.Fprintf(status, "Checking AWS bucket permissions...\n")
writePermission, err := au.client.CheckBucketPermission(au.bucketName, S3PermissionWrite)
if err != nil {
return err
}
if !writePermission {
return fmt.Errorf("you don't have write permissions to bucket '%s' with the given AWS account", au.bucketName)
}
fmt.Fprintf(status, "Upload conditions met.\n")
return nil
}

func (au *awsUploader) UploadAndRegister(r io.Reader, status io.Writer) (err error) {
keyName := fmt.Sprintf("%s-%s", uuid.New().String(), au.imageName)
fmt.Fprintf(status, "Uploading %s to %s:%s\n", au.imageName, au.bucketName, keyName)

res, err := au.client.UploadFromReader(r, au.bucketName, keyName)
if err != nil {
return err
}
defer func() {
if err != nil {
aErr := au.client.DeleteObject(au.bucketName, keyName)
fmt.Fprintf(status, "Deleted S3 object %s:%s\n", au.bucketName, keyName)
err = errors.Join(err, aErr)
}
}()
fmt.Fprintf(status, "File uploaded to %s\n", aws.StringValue(&res.Location))
if au.targetArch == "" {
au.targetArch = arch.Current().String()
}
bootMode := ec2.BootModeValuesUefiPreferred

fmt.Fprintf(status, "Registering AMI %s\n", au.imageName)
ami, snapshot, err := au.client.Register(au.imageName, au.bucketName, keyName, nil, au.targetArch, &bootMode, nil)
if err != nil {
return err
}

fmt.Fprintf(status, "Deleted S3 object %s:%s\n", au.bucketName, keyName)
if err := au.client.DeleteObject(au.bucketName, keyName); err != nil {
return err
}
fmt.Fprintf(status, "AMI registered: %s\nSnapshot ID: %s\n", aws.StringValue(ami), aws.StringValue(snapshot))
if err != nil {
return err
}

return nil
}
195 changes: 195 additions & 0 deletions pkg/cloud/awscloud/uploader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package awscloud_test

import (
"bytes"
"fmt"
"io"
"testing"

"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"

"github.com/osbuild/images/pkg/cloud/awscloud"
)

// XXX: put into a new "cloudtest" package?
type fakeAWSClient struct {
regions []string
regionsErr error
regionsCalls int

buckets []string
bucketsErr error
bucketsCalls int

checkBucketPermission bool
checkBucketPermissionErr error
checkBucketPermissionCalls int

uploadFromReader *s3manager.UploadOutput
uploadFromReaderErr error
uploadFromReaderCalls int

registerErr error
registerImageId string
registerSnapshotId string
registerCalls int

deleteObjectErr error
deleteObjectCalls int
}

func (fa *fakeAWSClient) Regions() ([]string, error) {
fa.regionsCalls++
return fa.regions, fa.regionsErr
}

func (fa *fakeAWSClient) Buckets() ([]string, error) {
fa.bucketsCalls++
return fa.buckets, fa.bucketsErr
}

func (fa *fakeAWSClient) CheckBucketPermission(string, awscloud.S3Permission) (bool, error) {
fa.checkBucketPermissionCalls++
return fa.checkBucketPermission, fa.checkBucketPermissionErr
}

func (fa *fakeAWSClient) UploadFromReader(io.Reader, string, string) (*s3manager.UploadOutput, error) {
fa.uploadFromReaderCalls++
return fa.uploadFromReader, fa.uploadFromReaderErr
}

func (fa *fakeAWSClient) Register(name, bucket, key string, shareWith []string, rpmArch string, bootMode, importRole *string) (*string, *string, error) {
fa.registerCalls++
return &fa.registerImageId, &fa.registerSnapshotId, fa.registerErr
}

func (fa *fakeAWSClient) DeleteObject(string, string) error {
fa.deleteObjectCalls++
return fa.deleteObjectErr
}

func TestUploaderCheckHappy(t *testing.T) {
fa := &fakeAWSClient{
regions: []string{"region"},
buckets: []string{"bucket"},
checkBucketPermission: true,
}
restore := awscloud.MockNewAwsClient(func(string) (awscloud.AwsClient, error) {
return fa, nil
})
defer restore()

uploader, err := awscloud.NewUploader("region", "bucket", "ami", nil)
assert.NoError(t, err)
var statusLog bytes.Buffer
err = uploader.Check(&statusLog)
assert.NoError(t, err)
assert.Equal(t, 1, fa.regionsCalls)
assert.Equal(t, 1, fa.bucketsCalls)
assert.Equal(t, 1, fa.checkBucketPermissionCalls)
expectedStatusLog := `Checking AWS region access...
Checking AWS bucket...
Checking AWS bucket permissions...
Upload conditions met.
`
assert.Equal(t, expectedStatusLog, statusLog.String())
}

type repeatReader struct{}

func (r *repeatReader) Read(p []byte) (int, error) {
for i := range p {
p[i] = 0x1
}
return len(p), nil
}

func TestUploaderUploadHappy(t *testing.T) {
uuid.SetRand(&repeatReader{})

fa := &fakeAWSClient{
uploadFromReader: &s3manager.UploadOutput{
Location: "some-location",
},
registerImageId: "image-id",
registerSnapshotId: "snapshot-id",
}
restore := awscloud.MockNewAwsClient(func(string) (awscloud.AwsClient, error) {
return fa, nil
})
defer restore()

fakeImage := bytes.NewBufferString("fake-aws-image")
uploader, err := awscloud.NewUploader("region", "bucket", "ami", nil)
assert.NoError(t, err)
var uploadLog bytes.Buffer
err = uploader.UploadAndRegister(fakeImage, &uploadLog)
assert.NoError(t, err)
assert.Equal(t, 1, fa.uploadFromReaderCalls)
assert.Equal(t, 1, fa.registerCalls)
assert.Equal(t, 1, fa.deleteObjectCalls)
expectedUploadLog := `Uploading ami to bucket:01010101-0101-4101-8101-010101010101-ami
File uploaded to some-location
Registering AMI ami
Deleted S3 object bucket:01010101-0101-4101-8101-010101010101-ami
AMI registered: image-id
Snapshot ID: snapshot-id
`
assert.Equal(t, expectedUploadLog, uploadLog.String())
}

func TestUploaderUploadButRegisterError(t *testing.T) {
uuid.SetRand(&repeatReader{})

fa := &fakeAWSClient{
uploadFromReader: &s3manager.UploadOutput{
Location: "some-location",
},
registerErr: fmt.Errorf("fake-register-err"),
}
restore := awscloud.MockNewAwsClient(func(string) (awscloud.AwsClient, error) {
return fa, nil
})
defer restore()

fakeImage := bytes.NewBufferString("fake-aws-image")
uploader, err := awscloud.NewUploader("region", "bucket", "ami", nil)
assert.NoError(t, err)
var uploadLog bytes.Buffer
err = uploader.UploadAndRegister(fakeImage, &uploadLog)
// XXX: this should probably have a context
assert.EqualError(t, err, "fake-register-err")
assert.Equal(t, 1, fa.uploadFromReaderCalls)
assert.Equal(t, 1, fa.registerCalls)
assert.Equal(t, 1, fa.deleteObjectCalls)
expectedUploadLog := `Uploading ami to bucket:01010101-0101-4101-8101-010101010101-ami
File uploaded to some-location
Registering AMI ami
Deleted S3 object bucket:01010101-0101-4101-8101-010101010101-ami
`
assert.Equal(t, expectedUploadLog, uploadLog.String())
}

func TestUploaderUploadButRegisterErrorAndDeleteError(t *testing.T) {
fa := &fakeAWSClient{
uploadFromReader: &s3manager.UploadOutput{
Location: "some-location",
},
registerErr: fmt.Errorf("fake-register-err"),
deleteObjectErr: fmt.Errorf("fake-delete-object-err"),
}
restore := awscloud.MockNewAwsClient(func(string) (awscloud.AwsClient, error) {
return fa, nil
})
defer restore()

fakeImage := bytes.NewBufferString("fake-aws-image")
uploader, err := awscloud.NewUploader("region", "bucket", "ami", nil)
assert.NoError(t, err)
var uploadLog bytes.Buffer
err = uploader.UploadAndRegister(fakeImage, &uploadLog)
// XXX: this should probably have a context
assert.EqualError(t, err, "fake-register-err\nfake-delete-object-err")
}
Loading

0 comments on commit 2aeafec

Please sign in to comment.