Skip to content

Commit

Permalink
feat: validate decompresed binary size
Browse files Browse the repository at this point in the history
  • Loading branch information
agparadiso committed Jan 29, 2025
1 parent bcca537 commit 47ea589
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 18 deletions.
50 changes: 32 additions & 18 deletions pkg/workflows/wasm/host/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,29 +71,31 @@ func (r *store) delete(id string) {
}

var (
defaultTickInterval = 100 * time.Millisecond
defaultTimeout = 10 * time.Second
defaultMinMemoryMBs = uint64(128)
DefaultInitialFuel = uint64(100_000_000)
defaultMaxFetchRequests = 5
defaultMaxCompressedBinarySize = 10 * 1024 * 1024 // 10 MB
defaultTickInterval = 100 * time.Millisecond
defaultTimeout = 10 * time.Second
defaultMinMemoryMBs = uint64(128)
DefaultInitialFuel = uint64(100_000_000)
defaultMaxFetchRequests = 5
defaultMaxCompressedBinarySize = 10 * 1024 * 1024 // 10 MB
defaultMaxDecompressedBinarySize = 100 * 1024 * 1024 // 100 MB
)

type DeterminismConfig struct {
// Seed is the seed used to generate cryptographically insecure random numbers in the module.
Seed int64
}
type ModuleConfig struct {
TickInterval time.Duration
Timeout *time.Duration
MaxMemoryMBs uint64
MinMemoryMBs uint64
InitialFuel uint64
Logger logger.Logger
IsUncompressed bool
Fetch func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error)
MaxFetchRequests int
MaxCompressedBinarySize uint64
TickInterval time.Duration
Timeout *time.Duration
MaxMemoryMBs uint64
MinMemoryMBs uint64
InitialFuel uint64
Logger logger.Logger
IsUncompressed bool
Fetch func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error)
MaxFetchRequests int
MaxCompressedBinarySize uint64
MaxDecompressedBinarySize uint64

// Labeler is used to emit messages from the module.
Labeler custmsg.MessageEmitter
Expand Down Expand Up @@ -173,6 +175,10 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig))
modCfg.MaxCompressedBinarySize = uint64(defaultMaxCompressedBinarySize)
}

if modCfg.MaxDecompressedBinarySize == 0 {
modCfg.MaxDecompressedBinarySize = uint64(defaultMaxDecompressedBinarySize)
}

// Take the max of the min and the configured max memory mbs.
// We do this because Go requires a minimum of 16 megabytes to run,
// and local testing has shown that with less than the min, some
Expand All @@ -196,10 +202,10 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig))
// validate the binary size before decompressing
// this is to prevent decompression bombs
if uint64(len(binary)) > modCfg.MaxCompressedBinarySize {
return nil, fmt.Errorf("binary size exceeds the maximum allowed size of %d bytes", modCfg.MaxCompressedBinarySize)
return nil, fmt.Errorf("compressed binary size exceeds the maximum allowed size of %d bytes", modCfg.MaxCompressedBinarySize)
}

rdr := brotli.NewReader(bytes.NewBuffer(binary))
rdr := io.LimitReader(brotli.NewReader(bytes.NewBuffer(binary)), int64(modCfg.MaxDecompressedBinarySize))
decompedBinary, err := io.ReadAll(rdr)
if err != nil {
return nil, fmt.Errorf("failed to decompress binary: %w", err)
Expand All @@ -208,6 +214,14 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig))
binary = decompedBinary
}

// Validate the decompressed binary size.
// io.LimitReader prevents decompression bombs by reading up to a set limit, but it will not return an error if the limit is reached.
// The Read() method will return io.EOF, and ReadAll will gracefully handle it and return nil.
// Because of this, we treat the limit as a non-inclusive limit. If the limit is reached, we return an error.
if uint64(len(binary)) >= modCfg.MaxDecompressedBinarySize {
return nil, fmt.Errorf("decompressed binary size reached the maximum allowed size of %d bytes", modCfg.MaxDecompressedBinarySize)
}

mod, err := wasmtime.NewModule(engine, binary)
if err != nil {
return nil, fmt.Errorf("error creating wasmtime module: %w", err)
Expand Down
20 changes: 20 additions & 0 deletions pkg/workflows/wasm/host/wasm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,26 @@ func TestModule_CompressedBinarySize(t *testing.T) {
})
}

func TestModule_DecompressedBinarySize(t *testing.T) {
t.Parallel()

// compressed binary size is 4.121 MB
// decompressed binary size is 23.7 MB
binary := createTestBinary(successBinaryCmd, successBinaryLocation, false, t)
t.Run("decompressed binary size is within the limit", func(t *testing.T) {
customDecompressedBinarySize := uint64(24 * 1024 * 1024)
_, err := NewModule(&ModuleConfig{IsUncompressed: false, MaxDecompressedBinarySize: customDecompressedBinarySize, Logger: logger.Test(t)}, binary)
require.NoError(t, err)
})

t.Run("decompressed binary size is bigger than the limit", func(t *testing.T) {
customDecompressedBinarySize := uint64(3 * 1024 * 1024)
_, err := NewModule(&ModuleConfig{IsUncompressed: false, MaxDecompressedBinarySize: customDecompressedBinarySize, Logger: logger.Test(t)}, binary)
decompressedSizeExceeded := fmt.Sprintf("decompressed binary size reached the maximum allowed size of %d bytes", customDecompressedBinarySize)
require.ErrorContains(t, err, decompressedSizeExceeded)
})
}

func TestModule_Sandbox_SleepIsStubbedOut(t *testing.T) {
t.Parallel()
ctx := tests.Context(t)
Expand Down

0 comments on commit 47ea589

Please sign in to comment.