From e6af314ab34acb4b7a4e9c2c5ab90510ebb638cf Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Tue, 20 Jun 2023 12:42:47 +0000 Subject: [PATCH] Add samples per second logging for `reverb_dataset.py` --- acme/datasets/reverb_benchmark.py | 40 ++++++++++++++++++------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/acme/datasets/reverb_benchmark.py b/acme/datasets/reverb_benchmark.py index 65af786f9a..71c95d44eb 100644 --- a/acme/datasets/reverb_benchmark.py +++ b/acme/datasets/reverb_benchmark.py @@ -19,6 +19,7 @@ import time from typing import Sequence +import itertools from absl import app from absl import logging @@ -75,22 +76,29 @@ def main(_): next_timestep = environment.step(action) adder.add(action, next_timestep, extras=()) - for batch_size in [256, 256 * 8, 256 * 64]: - for prefetch_size in [0, 1, 4]: - print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}') - ds = datasets.make_reverb_dataset( - table='default', - server_address=replay_client.server_address, - batch_size=batch_size, - prefetch_size=prefetch_size, - ) - it = ds.as_numpy_iterator() - - for iteration in range(3): - t = time.time() - for _ in range(1000): - _ = next(it) - print(f'Iteration {iteration} finished in {time.time() - t}s') + batch_sizes = [256, 256 * 8, 256 * 64] + prefetch_sizes = [0, 1, 4] + num_batches_per_iteration = 1000 + + for batch_size, prefetch_size in itertools.product(batch_sizes, prefetch_sizes): + print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}') + ds = datasets.make_reverb_dataset( + table='default', + server_address=replay_client.server_address, + batch_size=batch_size, + prefetch_size=prefetch_size, + ) + it = ds.as_numpy_iterator() + + for iteration in range(3): + start = time.time() + for _ in range(num_batches_per_iteration): + _ = next(it) + end = time.time() + duration_s = end - start + samples_per_second = batch_size * num_batches_per_iteration / duration_s + print(f'Iteration {iteration} finished in {duration_s:_.02f}s with ' + f'{samples_per_second:_.02f} samples/s.') if __name__ == '__main__':