diff --git a/commit/plugin_e2e_test.go b/commit/plugin_e2e_test.go index 26c32cc2f..7bc20e70b 100644 --- a/commit/plugin_e2e_test.go +++ b/commit/plugin_e2e_test.go @@ -736,7 +736,7 @@ func prepareCcipReaderMock( } if enableDiscovery { - ccipReader.EXPECT().DiscoverContracts(mock.Anything).Return(nil, nil) + ccipReader.EXPECT().DiscoverContracts(mock.Anything, mock.Anything).Return(nil, nil) ccipReader.EXPECT().Sync(mock.Anything, mock.Anything).Return(nil) } } @@ -801,8 +801,10 @@ func setupNode(params SetupNodeParams) nodeSetup { } homeChainReader.EXPECT().GetFChain().Return(fChain, nil) - homeChainReader.EXPECT(). - GetOCRConfigs(mock.Anything, params.donID, consts.PluginTypeCommit). + if params.enableDiscovery { + homeChainReader.EXPECT().GetAllChainConfigs().Return(params.chainCfg, nil) + } + homeChainReader.EXPECT().GetOCRConfigs(mock.Anything, params.donID, consts.PluginTypeCommit). Return(reader.ActiveAndCandidate{ ActiveConfig: reader.OCR3ConfigWithMeta{ ConfigDigest: params.reportingCfg.ConfigDigest, diff --git a/internal/mocks/inmem/ccipreader_inmem.go b/internal/mocks/inmem/ccipreader_inmem.go index dbe671443..93fa42d8b 100644 --- a/internal/mocks/inmem/ccipreader_inmem.go +++ b/internal/mocks/inmem/ccipreader_inmem.go @@ -162,7 +162,9 @@ func (r InMemoryCCIPReader) GetChainFeePriceUpdate( return nil } -func (r InMemoryCCIPReader) DiscoverContracts(ctx context.Context) (reader.ContractAddresses, error) { +func (r InMemoryCCIPReader) DiscoverContracts( + ctx context.Context, + allChains []cciptypes.ChainSelector) (reader.ContractAddresses, error) { return nil, nil } diff --git a/internal/plugincommon/discovery/processor.go b/internal/plugincommon/discovery/processor.go index bc8c6afde..4d1e3ea9a 100644 --- a/internal/plugincommon/discovery/processor.go +++ b/internal/plugincommon/discovery/processor.go @@ -6,6 +6,7 @@ import ( "github.com/smartcontractkit/libocr/commontypes" ragep2ptypes "github.com/smartcontractkit/libocr/ragep2p/types" + "golang.org/x/exp/maps" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -68,8 +69,12 @@ func (cdp *ContractDiscoveryProcessor) Observation( return dt.Observation{}, fmt.Errorf("unable to get fchain: %w, seqNr: %d", err, seqNr) } - // TODO: discover the full list of source chain selectors and pass it into DiscoverContracts. - contracts, err := (*cdp.reader).DiscoverContracts(ctx) + chainConfigs, err := cdp.homechain.GetAllChainConfigs() + if err != nil { + return dt.Observation{}, fmt.Errorf("unable to get chain configs: %w, seqNr: %d", err, seqNr) + } + + contracts, err := (*cdp.reader).DiscoverContracts(ctx, maps.Keys(chainConfigs)) if err != nil { return dt.Observation{}, fmt.Errorf("unable to discover contracts: %w, seqNr: %d", err, seqNr) } diff --git a/internal/plugincommon/discovery/processor_test.go b/internal/plugincommon/discovery/processor_test.go index 06e0ab443..0b56b4233 100644 --- a/internal/plugincommon/discovery/processor_test.go +++ b/internal/plugincommon/discovery/processor_test.go @@ -64,10 +64,11 @@ func TestContractDiscoveryProcessor_Observation_SupportsDest_HappyPath(t *testin } mockReader. EXPECT(). - DiscoverContracts(mock.Anything). + DiscoverContracts(mock.Anything, mock.Anything). Return(expectedContracts, nil) mockHomeChain.EXPECT().GetFChain().Return(expectedFChain, nil) + mockHomeChain.EXPECT().GetAllChainConfigs().Return(nil, nil) defer mockReader.AssertExpectations(t) defer mockHomeChain.AssertExpectations(t) @@ -140,10 +141,11 @@ func TestContractDiscoveryProcessor_Observation_SourceReadersNotReady(t *testing } mockReader. EXPECT(). - DiscoverContracts(mock.Anything). + DiscoverContracts(mock.Anything, mock.Anything). Return(nil, nil) mockHomeChain.EXPECT().GetFChain().Return(expectedFChain, nil) + mockHomeChain.EXPECT().GetAllChainConfigs().Return(nil, nil) defer mockReader.AssertExpectations(t) defer mockHomeChain.AssertExpectations(t) @@ -179,9 +181,10 @@ func TestContractDiscoveryProcessor_Observation_ErrorDiscoveringContracts(t *tes discoveryErr := fmt.Errorf("discovery error") mockReader. EXPECT(). - DiscoverContracts(mock.Anything). + DiscoverContracts(mock.Anything, mock.Anything). Return(nil, discoveryErr) mockHomeChain.EXPECT().GetFChain().Return(expectedFChain, nil) + mockHomeChain.EXPECT().GetAllChainConfigs().Return(nil, nil) defer mockReader.AssertExpectations(t) defer mockHomeChain.AssertExpectations(t) diff --git a/mocks/pkg/reader/ccip_reader.go b/mocks/pkg/reader/ccip_reader.go index 60042230b..ffe77a4f5 100644 --- a/mocks/pkg/reader/ccip_reader.go +++ b/mocks/pkg/reader/ccip_reader.go @@ -95,9 +95,9 @@ func (_c *MockCCIPReader_CommitReportsGTETimestamp_Call) RunAndReturn(run func(c return _c } -// DiscoverContracts provides a mock function with given fields: ctx -func (_m *MockCCIPReader) DiscoverContracts(ctx context.Context) (reader.ContractAddresses, error) { - ret := _m.Called(ctx) +// DiscoverContracts provides a mock function with given fields: ctx, allChains +func (_m *MockCCIPReader) DiscoverContracts(ctx context.Context, allChains []ccipocr3.ChainSelector) (reader.ContractAddresses, error) { + ret := _m.Called(ctx, allChains) if len(ret) == 0 { panic("no return value specified for DiscoverContracts") @@ -105,19 +105,19 @@ func (_m *MockCCIPReader) DiscoverContracts(ctx context.Context) (reader.Contrac var r0 reader.ContractAddresses var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (reader.ContractAddresses, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, []ccipocr3.ChainSelector) (reader.ContractAddresses, error)); ok { + return rf(ctx, allChains) } - if rf, ok := ret.Get(0).(func(context.Context) reader.ContractAddresses); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, []ccipocr3.ChainSelector) reader.ContractAddresses); ok { + r0 = rf(ctx, allChains) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(reader.ContractAddresses) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, []ccipocr3.ChainSelector) error); ok { + r1 = rf(ctx, allChains) } else { r1 = ret.Error(1) } @@ -132,13 +132,14 @@ type MockCCIPReader_DiscoverContracts_Call struct { // DiscoverContracts is a helper method to define mock.On call // - ctx context.Context -func (_e *MockCCIPReader_Expecter) DiscoverContracts(ctx interface{}) *MockCCIPReader_DiscoverContracts_Call { - return &MockCCIPReader_DiscoverContracts_Call{Call: _e.mock.On("DiscoverContracts", ctx)} +// - allChains []ccipocr3.ChainSelector +func (_e *MockCCIPReader_Expecter) DiscoverContracts(ctx interface{}, allChains interface{}) *MockCCIPReader_DiscoverContracts_Call { + return &MockCCIPReader_DiscoverContracts_Call{Call: _e.mock.On("DiscoverContracts", ctx, allChains)} } -func (_c *MockCCIPReader_DiscoverContracts_Call) Run(run func(ctx context.Context)) *MockCCIPReader_DiscoverContracts_Call { +func (_c *MockCCIPReader_DiscoverContracts_Call) Run(run func(ctx context.Context, allChains []ccipocr3.ChainSelector)) *MockCCIPReader_DiscoverContracts_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].([]ccipocr3.ChainSelector)) }) return _c } @@ -148,7 +149,7 @@ func (_c *MockCCIPReader_DiscoverContracts_Call) Return(_a0 reader.ContractAddre return _c } -func (_c *MockCCIPReader_DiscoverContracts_Call) RunAndReturn(run func(context.Context) (reader.ContractAddresses, error)) *MockCCIPReader_DiscoverContracts_Call { +func (_c *MockCCIPReader_DiscoverContracts_Call) RunAndReturn(run func(context.Context, []ccipocr3.ChainSelector) (reader.ContractAddresses, error)) *MockCCIPReader_DiscoverContracts_Call { _c.Call.Return(run) return _c } diff --git a/pkg/contractreader/extended_unit_test.go b/pkg/contractreader/extended_unit_test.go index fd8f87993..93532919a 100644 --- a/pkg/contractreader/extended_unit_test.go +++ b/pkg/contractreader/extended_unit_test.go @@ -5,9 +5,8 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" - "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/stretchr/testify/assert" ) func TestGetOneBinding(t *testing.T) { diff --git a/pkg/reader/ccip.go b/pkg/reader/ccip.go index d7e4172b8..059dcf1f2 100644 --- a/pkg/reader/ccip.go +++ b/pkg/reader/ccip.go @@ -13,13 +13,11 @@ import ( "time" mapset "github.com/deckarep/golang-set/v2" - "golang.org/x/exp/maps" - "golang.org/x/sync/errgroup" - "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/types/query" "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" + "golang.org/x/exp/maps" "github.com/smartcontractkit/chainlink-ccip/internal/plugincommon/consensus" @@ -445,7 +443,7 @@ func (r *ccipChainReader) NextSeqNum( ctx context.Context, chains []cciptypes.ChainSelector, ) (map[cciptypes.ChainSelector]cciptypes.SeqNum, error) { lggr := logutil.WithContextValues(ctx, r.lggr) - cfgs, err := r.getOffRampSourceChainsConfig(ctx, chains) + cfgs, err := r.getOffRampSourceChainsConfig(ctx, lggr, chains) if err != nil { return nil, fmt.Errorf("get source chains config: %w", err) } @@ -855,6 +853,7 @@ func chainSelectorToBytes16(chainSel cciptypes.ChainSelector) [16]byte { func (r *ccipChainReader) discoverOffRampContracts( ctx context.Context, lggr logger.Logger, + chains []cciptypes.ChainSelector, ) (ContractAddresses, error) { // Exit without an error if we cannot read the destination. if err := validateExtendedReaderExistence(r.contractReaders, r.destChain); err != nil { @@ -866,7 +865,8 @@ func (r *ccipChainReader) discoverOffRampContracts( // OnRamps are in the offRamp SourceChainConfig. { - sourceConfigs, err := r.getAllOffRampSourceChainsConfig(ctx, lggr) + sourceConfigs, err := r.getOffRampSourceChainsConfig(ctx, lggr, chains) + if err != nil { return nil, fmt.Errorf("unable to get SourceChainsConfig: %w", err) } @@ -923,13 +923,15 @@ func (r *ccipChainReader) discoverOffRampContracts( return resp, nil } -func (r *ccipChainReader) DiscoverContracts(ctx context.Context) (ContractAddresses, error) { +func (r *ccipChainReader) DiscoverContracts(ctx context.Context, + chains []cciptypes.ChainSelector, +) (ContractAddresses, error) { + var resp ContractAddresses lggr := logutil.WithContextValues(ctx, r.lggr) - resp := make(ContractAddresses) // Discover destination contracts if the dest chain is supported. if err := validateExtendedReaderExistence(r.contractReaders, r.destChain); err == nil { - resp, err = r.discoverOffRampContracts(ctx, lggr) + resp, err = r.discoverOffRampContracts(ctx, lggr, chains) // Can't continue with discovery if the destination chain is not available. // We read source chains OnRamps from there, and onRamps are essential for feeQuoter and Router discovery. if err != nil { @@ -1146,117 +1148,76 @@ func (scc sourceChainConfig) check() (bool /* enabled */, error) { return scc.IsEnabled, nil } -// getOffRampSourceChainsConfig returns the offRamp contract's source chain configurations for each supported source -// chain. If some chain is disabled it is not included in the response. +// getOffRampSourceChainsConfig get all enabled source chain configs from the offRamp for dest chain func (r *ccipChainReader) getOffRampSourceChainsConfig( - ctx context.Context, chains []cciptypes.ChainSelector) (map[cciptypes.ChainSelector]sourceChainConfig, error) { + ctx context.Context, + lggr logger.Logger, + chains []cciptypes.ChainSelector, +) (map[cciptypes.ChainSelector]sourceChainConfig, error) { if err := validateExtendedReaderExistence(r.contractReaders, r.destChain); err != nil { - return nil, err + return nil, fmt.Errorf("validate extended reader existence: %w", err) } - res := make(map[cciptypes.ChainSelector]sourceChainConfig) - mu := new(sync.Mutex) + configs := make(map[cciptypes.ChainSelector]sourceChainConfig) + contractBatch := make(types.ContractBatch, 0, len(chains)) + sourceChains := make([]any, 0, len(chains)) - eg := new(errgroup.Group) - for _, chainSel := range chains { - if chainSel == r.destChain { + for _, chain := range chains { + if chain == r.destChain { continue } + sourceChains = append(sourceChains, chain) - // TODO: look into using BatchGetLatestValue instead to simplify concurrency? - eg.Go(func() error { - resp := sourceChainConfig{} - err := r.contractReaders[r.destChain].ExtendedGetLatestValue( - ctx, - consts.ContractNameOffRamp, - consts.MethodNameGetSourceChainConfig, - primitives.Unconfirmed, - map[string]any{ - "sourceChainSelector": chainSel, - }, - &resp, - ) - if err != nil { - return fmt.Errorf("failed to get source chain config for source chain %d: %w", - chainSel, err) - } - - enabled, err := resp.check() - if err != nil { - return fmt.Errorf("source chain config check for chain %d failed: %w", chainSel, err) - } - if !enabled { - // We don't want to process disabled chains prematurely. - r.lggr.Debugw("source chain is disabled", "chain", chainSel) - return nil - } - - mu.Lock() - res[chainSel] = resp - mu.Unlock() - return nil + contractBatch = append(contractBatch, types.BatchRead{ + ReadName: consts.MethodNameGetSourceChainConfig, + Params: map[string]any{ + "sourceChainSelector": chain, + }, + ReturnVal: new(sourceChainConfig), }) } - if err := eg.Wait(); err != nil { - return nil, err - } - return res, nil -} - -// selectorsAndConfigs wraps the return values from getAllSourceChainConfigs. -type selectorsAndConfigs struct { - Selectors []uint64 `mapstructure:"F0"` - SourceChainConfigs []sourceChainConfig `mapstructure:"F1"` -} - -// getAllOffRampSourceChainsConfig get all enabled source chain configs from the offRamp for dest chain -func (r *ccipChainReader) getAllOffRampSourceChainsConfig( - ctx context.Context, - lggr logger.Logger, -) (map[cciptypes.ChainSelector]sourceChainConfig, error) { - if err := validateExtendedReaderExistence(r.contractReaders, r.destChain); err != nil { - return nil, fmt.Errorf("validate extended reader existence: %w", err) - } - - configs := make(map[cciptypes.ChainSelector]sourceChainConfig) - - var resp selectorsAndConfigs - err := r.contractReaders[r.destChain].ExtendedGetLatestValue( - ctx, - consts.ContractNameOffRamp, - consts.MethodNameOffRampGetAllSourceChainConfigs, - primitives.Unconfirmed, - map[string]any{}, - &resp, + results, _, err := r.contractReaders[r.destChain].ExtendedBatchGetLatestValues( + ctx, contractreader.ExtendedBatchGetLatestValuesRequest{consts.ContractNameOffRamp: contractBatch}, + false, ) + if err != nil { - return nil, fmt.Errorf("failed to get source chain configs for source chain %d: %w", + return nil, fmt.Errorf("failed to get source chain configs for dest chain %d: %w", r.destChain, err) } - if len(resp.SourceChainConfigs) != len(resp.Selectors) { - return nil, fmt.Errorf("selectors and source chain configs length mismatch: %v", resp) - } - - lggr.Debugw("got source chain configs", "configs", resp) + lggr.Debugw("got source chain configs", "configs", results) // Populate the map. - for i := range resp.Selectors { - chainSel := cciptypes.ChainSelector(resp.Selectors[i]) - cfg := resp.SourceChainConfigs[i] - - enabled, err := cfg.check() - if err != nil { - return nil, fmt.Errorf("source chain config check for chain %d failed: %w", chainSel, err) - } - if !enabled { - // We don't want to process disabled chains prematurely. - lggr.Debugw("source chain is disabled", "chain", chainSel) - continue + for _, readResult := range results { + if len(readResult) != len(sourceChains) { + return nil, fmt.Errorf("selectors and source chain configs length mismatch: sourceChains=%v, configs=%v", + sourceChains, results) } + for i, chainSel := range sourceChains { + v, err := readResult[i].GetResult() + if err != nil { + return nil, fmt.Errorf("GetSourceChainConfig for chainSelector=%d failed: %w", chainSel, err) + } - configs[chainSel] = cfg + cfg, ok := v.(*sourceChainConfig) + if !ok { + return nil, fmt.Errorf("invalid result type from GetSourceChainConfig for chainSelector=%d: %w", chainSel, err) + } + + enabled, err := cfg.check() + if err != nil { + return nil, fmt.Errorf("source chain config check for chain %d failed: %w", chainSel, err) + } + if !enabled { + // We don't want to process disabled chains prematurely. + lggr.Debugw("source chain is disabled", "chain", chainSel) + continue + } + + configs[chainSel.(cciptypes.ChainSelector)] = *cfg + } } return configs, nil diff --git a/pkg/reader/ccip_interface.go b/pkg/reader/ccip_interface.go index 4a8fb3d90..9d2ac262a 100644 --- a/pkg/reader/ccip_interface.go +++ b/pkg/reader/ccip_interface.go @@ -150,8 +150,11 @@ type CCIPReader interface { // from the destination chain RMN remote contract. Caller should be able to access destination. GetRmnCurseInfo(ctx context.Context, sourceChainSelectors []cciptypes.ChainSelector) (*CurseInfo, error) - // DiscoverContracts reads from all available contract readers to discover contract addresses. - DiscoverContracts(ctx context.Context) (ContractAddresses, error) + // DiscoverContracts reads the destination chain for contract addresses. They are returned per + // contract and source chain selector. + // allChains is needed because there is no way to enumerate all chain selectors on Solana. We'll attempt to + // fetch the source config from the offramp for each of them. + DiscoverContracts(ctx context.Context, allChains []cciptypes.ChainSelector) (ContractAddresses, error) // LinkPriceUSD gets the LINK price in 1e-18 USDs from the FeeQuoter contract on the destination chain. // For example, if the price is 1 LINK = 10 USD, this function will return 10e18 (10 * 1e18). You can think of this diff --git a/pkg/reader/ccip_test.go b/pkg/reader/ccip_test.go index 1aa50dd11..e786bc751 100644 --- a/pkg/reader/ccip_test.go +++ b/pkg/reader/ccip_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.uber.org/zap/zapcore" + "golang.org/x/exp/maps" "github.com/smartcontractkit/chainlink-ccip/mocks/pkg/types/ccipocr3" @@ -47,30 +48,36 @@ func TestCCIPChainReader_getSourceChainsConfig(t *testing.T) { destCR := reader_mocks.NewMockContractReaderFacade(t) destCR.EXPECT().Bind(mock.Anything, mock.Anything).Return(nil) destCR.EXPECT().HealthReport().Return(nil) - destCR.EXPECT().GetLatestValue( - mock.Anything, - mock.Anything, - mock.Anything, + destCR.EXPECT().BatchGetLatestValues( mock.Anything, mock.Anything, - ).Run(func( + ).RunAndReturn(func( ctx context.Context, - readIdentifier string, - confidenceLevel primitives.ConfidenceLevel, - params interface{}, - returnVal interface{}, - ) { - sourceChain := params.(map[string]any)["sourceChainSelector"].(cciptypes.ChainSelector) - v := returnVal.(*sourceChainConfig) - - fromString, err := cciptypes.NewBytesFromString(fmt.Sprintf( - "0x%d000000000000000000000000000000000000000", sourceChain), - ) - require.NoError(t, err) - v.OnRamp = cciptypes.UnknownAddress(fromString) - v.IsEnabled = true - v.Router = cciptypes.UnknownAddress(fromString) - }).Return(nil) + request types.BatchGetLatestValuesRequest, + ) (types.BatchGetLatestValuesResult, error) { + results := make(types.BatchGetLatestValuesResult, 0) + for contractName, batch := range request { + for _, readReq := range batch { + res := types.BatchReadResult{ + ReadName: readReq.ReadName, + } + params := readReq.Params.(map[string]any) + sourceChain := params["sourceChainSelector"].(cciptypes.ChainSelector) + v := readReq.ReturnVal.(*sourceChainConfig) + + fromString, err := cciptypes.NewBytesFromString(fmt.Sprintf( + "0x%d000000000000000000000000000000000000000", sourceChain), + ) + require.NoError(t, err) + v.OnRamp = cciptypes.UnknownAddress(fromString) + v.IsEnabled = true + v.Router = fromString + res.SetResult(v, nil) + results[contractName] = append(results[contractName], res) + } + } + return results, nil + }) offrampAddress := []byte{0x3} ccipReader := newCCIPChainReaderInternal( @@ -93,7 +100,7 @@ func TestCCIPChainReader_getSourceChainsConfig(t *testing.T) { Address: typeconv.AddressBytesToString(offrampAddress, 111_111)}})) ctx := context.Background() - cfgs, err := ccipReader.getOffRampSourceChainsConfig(ctx, []cciptypes.ChainSelector{chainA, chainB}) + cfgs, err := ccipReader.getOffRampSourceChainsConfig(ctx, logger.Test(t), []cciptypes.ChainSelector{chainA, chainB}) assert.NoError(t, err) assert.Len(t, cfgs, 2) assert.Equal(t, "0x1000000000000000000000000000000000000000", cfgs[chainA].OnRamp.String()) @@ -432,9 +439,19 @@ func TestCCIPChainReader_DiscoverContracts_HappyPath_Round1(t *testing.T) { destRMNRemote := []byte{0x4} destFeeQuoter := []byte{0x5} destRouter := []byte{0x6} - //srcRouters := []byte{0x7, 0x8} + srcRouters := [][]byte{{0x7}, {0x8}} //srcFeeQuoters := [2][]byte{{0x7}, {0x8}} + sourceChainConfigs := make(map[cciptypes.ChainSelector]sourceChainConfig, len(sourceChain)) + for i, chain := range sourceChain { + sourceChainConfigs[chain] = sourceChainConfig{ + Router: srcRouters[i], + IsEnabled: true, + MinSeqNr: 0, + OnRamp: onramps[i], + } + } + // Build expected addresses. var expectedContractAddresses ContractAddresses // Source FeeQuoter's and destRouter are missing. @@ -452,29 +469,22 @@ func TestCCIPChainReader_DiscoverContracts_HappyPath_Round1(t *testing.T) { mockReaders[destChain] = reader_mocks.NewMockExtended(t) addDestinationContractAssertions(mockReaders[destChain], destNonceMgr, destRMNRemote, destFeeQuoter) - mockReaders[destChain].EXPECT().ExtendedGetLatestValue( + mockReaders[destChain].EXPECT().ExtendedBatchGetLatestValues( mock.Anything, - consts.ContractNameOffRamp, - consts.MethodNameOffRampGetAllSourceChainConfigs, - primitives.Unconfirmed, - map[string]any{}, mock.Anything, - ).Return(nil).Run(withReturnValueOverridden(func(returnVal interface{}) { - v := returnVal.(*selectorsAndConfigs) - v.Selectors = []uint64{uint64(sourceChain[0]), uint64(sourceChain[1])} - v.SourceChainConfigs = []sourceChainConfig{ - { - OnRamp: onramps[0], - Router: destRouter, - IsEnabled: true, - }, - { - OnRamp: onramps[1], - Router: destRouter, - IsEnabled: true, - }, - } - })) + mock.Anything, + ).RunAndReturn(withBatchGetLatestValuesRetValues(t, + "0x1234567890123456789012345678901234567890", + []any{&sourceChainConfig{ + OnRamp: onramps[0], + Router: destRouter, + IsEnabled: true, + }, &sourceChainConfig{ + OnRamp: onramps[1], + Router: destRouter, + IsEnabled: true, + }}, + )) // mock calls to get fee quoter from onramps and source chain config from offramp. for _, selector := range sourceChain { @@ -515,7 +525,7 @@ func TestCCIPChainReader_DiscoverContracts_HappyPath_Round1(t *testing.T) { lggr: lggr, } - contractAddresses, err := ccipChainReader.DiscoverContracts(ctx) + contractAddresses, err := ccipChainReader.DiscoverContracts(ctx, sourceChain[:]) require.NoError(t, err) assert.Equal(t, expectedContractAddresses, contractAddresses) @@ -595,29 +605,23 @@ func TestCCIPChainReader_DiscoverContracts_HappyPath_Round2(t *testing.T) { mockReaders[destChain] = reader_mocks.NewMockExtended(t) addDestinationContractAssertions(mockReaders[destChain], destNonceMgr, destRMNRemote, destFeeQuoter) - mockReaders[destChain].EXPECT().ExtendedGetLatestValue( + mockReaders[destChain].EXPECT().ExtendedBatchGetLatestValues( mock.Anything, - consts.ContractNameOffRamp, - consts.MethodNameOffRampGetAllSourceChainConfigs, - primitives.Unconfirmed, - map[string]any{}, mock.Anything, - ).Return(nil).Run(withReturnValueOverridden(func(returnVal interface{}) { - v := returnVal.(*selectorsAndConfigs) - v.Selectors = []uint64{uint64(sourceChain[0]), uint64(sourceChain[1])} - v.SourceChainConfigs = []sourceChainConfig{ - { - OnRamp: onramps[0], - Router: destRouter[0], - IsEnabled: true, - }, - { + mock.Anything, + ).RunAndReturn(withBatchGetLatestValuesRetValues(t, + "0x1234567890123456789012345678901234567890", + []any{&sourceChainConfig{ + OnRamp: onramps[0], + Router: destRouter[0], + IsEnabled: true, + }, + &sourceChainConfig{ OnRamp: onramps[1], Router: destRouter[1], IsEnabled: true, }, - } - })) + })) // mock calls to get fee quoter from onramps and source chain config from offramp. for i, selector := range sourceChain { @@ -662,7 +666,7 @@ func TestCCIPChainReader_DiscoverContracts_HappyPath_Round2(t *testing.T) { lggr: logger.Test(t), } - contractAddresses, err := ccipChainReader.DiscoverContracts(ctx) + contractAddresses, err := ccipChainReader.DiscoverContracts(ctx, sourceChain[:]) require.NoError(t, err) require.Equal(t, expectedContractAddresses, contractAddresses) @@ -677,14 +681,11 @@ func TestCCIPChainReader_DiscoverContracts_GetAllSourceChainConfig_Errors(t *tes // mock the call for sourceChain2 - failure getLatestValueErr := errors.New("some error") - destExtended.EXPECT().ExtendedGetLatestValue( + destExtended.EXPECT().ExtendedBatchGetLatestValues( mock.Anything, - consts.ContractNameOffRamp, - consts.MethodNameOffRampGetAllSourceChainConfigs, - primitives.Unconfirmed, - map[string]any{}, mock.Anything, - ).Return(getLatestValueErr) + mock.Anything, + ).Return(nil, nil, getLatestValueErr) // get static config call won't occur because the source chain config call failed. @@ -702,7 +703,7 @@ func TestCCIPChainReader_DiscoverContracts_GetAllSourceChainConfig_Errors(t *tes lggr: logger.Test(t), } - _, err := ccipChainReader.DiscoverContracts(ctx) + _, err := ccipChainReader.DiscoverContracts(ctx, []cciptypes.ChainSelector{sourceChain1, sourceChain2}) require.Error(t, err) require.ErrorIs(t, err, getLatestValueErr) } @@ -715,15 +716,15 @@ func TestCCIPChainReader_DiscoverContracts_GetOfframpStaticConfig_Errors(t *test destExtended := reader_mocks.NewMockExtended(t) // mock the call for source chain configs - destExtended.EXPECT().ExtendedGetLatestValue( + destExtended.EXPECT().ExtendedBatchGetLatestValues( mock.Anything, - consts.ContractNameOffRamp, - consts.MethodNameOffRampGetAllSourceChainConfigs, - primitives.Unconfirmed, - map[string]any{}, mock.Anything, - ).Return(nil) // doesn't matter for this test - // mock the call to get the nonce manager - failure + mock.Anything, + ).RunAndReturn(withBatchGetLatestValuesRetValues(t, + "0x1234567890123456789012345678901234567890", + []any{&sourceChainConfig{}, &sourceChainConfig{}})) + + // mock the call to get the static config - failure getLatestValueErr := errors.New("some error") destExtended.EXPECT().ExtendedGetLatestValue( mock.Anything, @@ -748,7 +749,7 @@ func TestCCIPChainReader_DiscoverContracts_GetOfframpStaticConfig_Errors(t *test lggr: logger.Test(t), } - _, err := ccipChainReader.DiscoverContracts(ctx) + _, err := ccipChainReader.DiscoverContracts(ctx, []cciptypes.ChainSelector{sourceChain1, sourceChain2}) require.Error(t, err) require.ErrorIs(t, err, getLatestValueErr) } @@ -771,6 +772,42 @@ func withReturnValueOverridden(mapper func(returnVal interface{})) func(ctx cont } } +// withBatchGetLatestValuesRetValues returns a mock ExtendedBatchGetLatestValues() method +// which can be passed to RunAndReturn(), given a set of return values and an address as input +// Only supports a single contract +func withBatchGetLatestValuesRetValues( + t testing.TB, + address string, + retVals []any) func( + context.Context, + contractreader.ExtendedBatchGetLatestValuesRequest, + bool) (types.BatchGetLatestValuesResult, []string, error) { + return func( + ctx context.Context, req contractreader.ExtendedBatchGetLatestValuesRequest, graceful bool, + ) (types.BatchGetLatestValuesResult, []string, error) { + require.GreaterOrEqual(t, len(retVals), 1) + _, ok := retVals[0].(*sourceChainConfig) + require.True(t, ok) + require.Len(t, req, 1) + contract := maps.Keys(req)[0] + batchRequest := maps.Values(req)[0] + require.Equal(t, len(retVals), len(batchRequest)) + + results := make(types.ContractBatchResults, 0, len(retVals)) + for i, retVal := range retVals { + res := types.BatchReadResult{ReadName: batchRequest[i].ReadName} + res.SetResult(retVal, nil) + results = append(results, res) + } + boundContract := types.BoundContract{ + Address: address, + Name: contract, + } + + return types.BatchGetLatestValuesResult{boundContract: results}, nil, nil + } +} + func TestCCIPChainReader_getDestFeeQuoterStaticConfig(t *testing.T) { destCR := reader_mocks.NewMockContractReaderFacade(t) destCR.EXPECT().Bind(mock.Anything, mock.Anything).Return(nil)