Skip to content

Commit

Permalink
cache
Browse files Browse the repository at this point in the history
  • Loading branch information
ettec committed Feb 5, 2025
1 parent f973d5a commit 409f915
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 104 deletions.
117 changes: 77 additions & 40 deletions pkg/workflows/wasm/host/module.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package host

import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -102,13 +105,53 @@ type ModuleConfig struct {
Determinism *DeterminismConfig
}

type WasmtimeModuleFactoryFn func(engine *wasmtime.Engine, binary []byte, isUncompressed bool, maxCompressedBinarySize uint64, maxDecompressedBinarySize uint64) (*wasmtime.Module, error)
type BinaryID string

// FromBinary converts a given Binary to its ID.
func FromBinary(binary []byte) BinaryID {
sha := sha256.Sum256(binary)
return BinaryID(hex.EncodeToString(sha[:]))
}

type WasmTimeModule struct {
engine *wasmtime.Engine
module *wasmtime.Module
config *wasmtime.Config
}

func NewWasmTimeModule(engine *wasmtime.Engine, module *wasmtime.Module, config *wasmtime.Config) *WasmTimeModule {
return &WasmTimeModule{engine: engine, module: module, config: config}
}

// Close closes the module and its dependencies.
func (w *WasmTimeModule) Close() error {
w.engine.Close()
w.module.Close()
w.config.Close()
return nil
}

func ModuleEquals(this *WasmTimeModule, that *WasmTimeModule) (bool, error) {
thisBytes, err := this.module.Serialize()
if err != nil {
return false, err
}

thatBytes, err := that.module.Serialize()
if err != nil {
return false, err
}

return bytes.Equal(thisBytes, thatBytes), nil
}

// WasmtimeModuleFactoryFn is a function that creates a wasmtime.Module from a given serialise wasm binary.
type WasmtimeModuleFactoryFn func(initialFuel uint64) (*WasmTimeModule, error)

type Module struct {
engine *wasmtime.Engine
module *wasmtime.Module
linker *wasmtime.Linker
wconfig *wasmtime.Config
wasmtimeModule *WasmTimeModule

linker *wasmtime.Linker

requestStore *store

Expand All @@ -132,7 +175,7 @@ func WithDeterminism() func(*ModuleConfig) {
}
}

func NewModule(modCfg *ModuleConfig, binary []byte, wasmtimeModuleFactory WasmtimeModuleFactoryFn,
func NewModule(modCfg *ModuleConfig, wasmtimeModuleFactory WasmtimeModuleFactoryFn,
opts ...func(*ModuleConfig)) (*Module, error) {
// Apply options to the module config.
for _, opt := range opts {
Expand Down Expand Up @@ -171,40 +214,18 @@ func NewModule(modCfg *ModuleConfig, binary []byte, wasmtimeModuleFactory Wasmti
modCfg.MinMemoryMBs = defaultMinMemoryMBs
}

if modCfg.MaxCompressedBinarySize == 0 {
modCfg.MaxCompressedBinarySize = uint64(defaultMaxCompressedBinarySize)
}

if modCfg.MaxDecompressedBinarySize == 0 {
modCfg.MaxDecompressedBinarySize = uint64(defaultMaxDecompressedBinarySize)
}

// Take the max of the min and the configured max memory mbs.
// We do this because Go requires a minimum of 16 megabytes to run,
// and local testing has shown that with less than the min, some
// binaries may error sporadically.
modCfg.MaxMemoryMBs = uint64(math.Max(float64(modCfg.MinMemoryMBs), float64(modCfg.MaxMemoryMBs)))

cfg := wasmtime.NewConfig()
cfg.SetEpochInterruption(true)
if modCfg.InitialFuel > 0 {
cfg.SetConsumeFuel(true)
}

cfg.CacheConfigLoadDefault()
cfg.SetCraneliftOptLevel(wasmtime.OptLevelSpeedAndSize)

// Load testing shows that leaving native unwind info enabled causes a very large slowdown when loading multiple modules.
cfg.SetNativeUnwindInfo(false)

engine := wasmtime.NewEngineWithConfig(cfg)

mod, err := wasmtimeModuleFactory(engine, binary, modCfg.IsUncompressed, modCfg.MaxCompressedBinarySize, modCfg.MaxDecompressedBinarySize)
wasmModule, err := wasmtimeModuleFactory(modCfg.InitialFuel)
if err != nil {
return nil, fmt.Errorf("error creating wasmtime module: %w", err)
}

linker, err := newWasiLinker(modCfg, engine)
linker, err := newWasiLinker(modCfg, wasmModule.engine)
if err != nil {
return nil, fmt.Errorf("error creating wasi linker: %w", err)
}
Expand Down Expand Up @@ -250,10 +271,8 @@ func NewModule(modCfg *ModuleConfig, binary []byte, wasmtimeModuleFactory Wasmti
}

m := &Module{
engine: engine,
module: mod,
linker: linker,
wconfig: cfg,
wasmtimeModule: wasmModule,
linker: linker,

requestStore: requestStore,

Expand All @@ -265,6 +284,26 @@ func NewModule(modCfg *ModuleConfig, binary []byte, wasmtimeModuleFactory Wasmti
return m, nil
}

func GetEngineConfiguration(initialFuel uint64) (*wasmtime.Config, error) {
cfg := wasmtime.NewConfig()
cfg.SetEpochInterruption(true)
if initialFuel > 0 {
cfg.SetConsumeFuel(true)
}

err := cfg.CacheConfigLoadDefault()
if err != nil {
return nil, fmt.Errorf("error loading default cache config: %w", err)
}

cfg.SetCraneliftOptLevel(wasmtime.OptLevelSpeedAndSize)

// Load testing shows that leaving native unwind info enabled causes a very large slowdown when loading multiple modules.
cfg.SetNativeUnwindInfo(false)

return cfg, nil
}

func (m *Module) Start() {
m.wg.Add(1)
go func() {
Expand All @@ -276,7 +315,7 @@ func (m *Module) Start() {
case <-m.stopCh:
return
case <-ticker.C:
m.engine.IncrementEpoch()
m.wasmtimeModule.engine.IncrementEpoch()
}
}
}()
Expand All @@ -287,9 +326,7 @@ func (m *Module) Close() {
m.wg.Wait()

m.linker.Close()
m.engine.Close()
m.module.Close()
m.wconfig.Close()
m.wasmtimeModule.Close()
}

func (m *Module) Run(ctx context.Context, request *wasmpb.Request) (*wasmpb.Response, error) {
Expand All @@ -312,7 +349,7 @@ func (m *Module) Run(ctx context.Context, request *wasmpb.Request) (*wasmpb.Resp
// we delete the request data from the store when we're done
defer m.requestStore.delete(request.Id)

store := wasmtime.NewStore(m.engine)
store := wasmtime.NewStore(m.wasmtimeModule.engine)
defer store.Close()

reqpb, err := proto.Marshal(request)
Expand Down Expand Up @@ -348,7 +385,7 @@ func (m *Module) Run(ctx context.Context, request *wasmpb.Request) (*wasmpb.Resp
deadline := *m.cfg.Timeout / m.cfg.TickInterval
store.SetEpochDeadline(uint64(deadline))

instance, err := m.linker.Instantiate(store, m.module)
instance, err := m.linker.Instantiate(store, m.wasmtimeModule.module)
if err != nil {
return nil, err
}
Expand Down
42 changes: 40 additions & 2 deletions pkg/workflows/wasm/host/wasm.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package host

import (
"bytes"
"context"
"errors"
"fmt"
"io"

"github.com/andybalholm/brotli"
"github.com/google/uuid"

"google.golang.org/protobuf/types/known/emptypb"
Expand All @@ -13,9 +16,9 @@ import (
wasmpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb"
)

func GetWorkflowSpec(ctx context.Context, modCfg *ModuleConfig, binary []byte, config []byte,
func GetWorkflowSpec(ctx context.Context, modCfg *ModuleConfig, config []byte,
wasmtimeModuleFactory WasmtimeModuleFactoryFn) (*sdk.WorkflowSpec, error) {
m, err := NewModule(modCfg, binary, wasmtimeModuleFactory, WithDeterminism())
m, err := NewModule(modCfg, wasmtimeModuleFactory, WithDeterminism())
if err != nil {
return nil, fmt.Errorf("could not instantiate module: %w", err)
}
Expand Down Expand Up @@ -44,3 +47,38 @@ func GetWorkflowSpec(ctx context.Context, modCfg *ModuleConfig, binary []byte, c

return wasmpb.ProtoToWorkflowSpec(sr)
}

func ValidateAndDecompressBinary(binary []byte, isUncompressed bool, maxCompressedBinarySize uint64, maxDecompressedBinarySize uint64) ([]byte, error) {

if maxCompressedBinarySize == 0 {
maxCompressedBinarySize = uint64(defaultMaxCompressedBinarySize)
}

if maxDecompressedBinarySize == 0 {
maxDecompressedBinarySize = uint64(defaultMaxDecompressedBinarySize)
}

if !isUncompressed {
// validate the binary size before decompressing
// this is to prevent decompression bombs
if uint64(len(binary)) > maxCompressedBinarySize {
return nil, fmt.Errorf("compressed binary size exceeds the maximum allowed size of %d bytes", maxCompressedBinarySize)
}

rdr := io.LimitReader(brotli.NewReader(bytes.NewBuffer(binary)), int64(maxDecompressedBinarySize+1))
decompedBinary, err := io.ReadAll(rdr)
if err != nil {
return nil, fmt.Errorf("failed to decompress binary: %w", err)
}

binary = decompedBinary
}

// Validate the decompressed binary size.
// io.LimitReader prevents decompression bombs by reading up to a set limit, but it will not return an error if the limit is reached.
// The Read() method will return io.EOF, and ReadAll will gracefully handle it and return nil.
if uint64(len(binary)) > maxDecompressedBinarySize {
return nil, fmt.Errorf("decompressed binary size reached the maximum allowed size of %d bytes", maxDecompressedBinarySize)
}
return binary, nil
}
Loading

0 comments on commit 409f915

Please sign in to comment.