Skip to content

Commit

Permalink
Add samples per second logging for reverb_dataset.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hartikainen committed Jun 20, 2023
1 parent 98c4204 commit 9694646
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions acme/datasets/reverb_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import time
from typing import Sequence
import itertools

from absl import app
from absl import logging
Expand Down Expand Up @@ -75,22 +76,29 @@ def main(_):
next_timestep = environment.step(action)
adder.add(action, next_timestep, extras=())

for batch_size in [256, 256 * 8, 256 * 64]:
for prefetch_size in [0, 1, 4]:
print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}')
ds = datasets.make_reverb_dataset(
table='default',
server_address=replay_client.server_address,
batch_size=batch_size,
prefetch_size=prefetch_size,
)
it = ds.as_numpy_iterator()

for iteration in range(3):
t = time.time()
for _ in range(1000):
_ = next(it)
print(f'Iteration {iteration} finished in {time.time() - t}s')
batch_sizes = [256, 256 * 8, 256 * 64]
prefetch_sizes = [0, 1, 4]
num_batches_per_iteration = 1000

for batch_size, prefetch_size in itertools.product(batch_sizes, prefetch_sizes):
print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}')
ds = datasets.make_reverb_dataset(
table='default',
server_address=replay_client.server_address,
batch_size=batch_size,
prefetch_size=prefetch_size,
)
it = ds.as_numpy_iterator()

for iteration in range(3):
start = time.time()
for _ in range(num_batches_per_iteration):
_ = next(it)
end = time.time()
duration_s = end - start
samples_per_second = batch_size * num_batches_per_iteration / duration_s
print(f'Iteration {iteration} finished in {duration_s:_.02}s with '
f'{samples_per_second:_.2f} samples/s.')


if __name__ == '__main__':
Expand Down

0 comments on commit 9694646

Please sign in to comment.