I ran these on the following hardware:
- Intel Xeon E5-2650 v4 @ 2.20 GHz
- 512 GB DDR4 memory (not that we would need it)
- NVIDIA Tesla P100 (16 GB memory)
Software stack:
- CentOS 7
- GNU compiler toolkit 8.3.0
- Python 3.8.12
- CUDA 11.2
- Packages pulled from pip
- Backend versions: aesara==2.2.4 cupy==9.5.0 jax==0.2.24 numba==0.54.1 numpy==1.19.5 pytorch==1.10.0 tensorflow==2.6.0
This benchmark evaluates an equation consisting of more than 100 terms with no data dependencies and only elementary math. It should represent a best-case scenario for vector instructions and GPU performance.
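To illustrate what "no data dependencies, only elementary math" means, here is a hypothetical toy kernel (not the actual benchmark code): every output element is computed independently from the corresponding input elements, so the work maps trivially onto SIMD lanes or GPU threads.

```python
import numpy as np

def toy_kernel(x, y):
    # Hypothetical stand-in for the real equation of state:
    # purely elementwise arithmetic, no element depends on any other.
    return np.sqrt(x) * np.sin(y) + 0.5 * x * y - x ** 2 / (1.0 + y ** 2)

x = np.random.rand(4096)
y = np.random.rand(4096)
out = toy_kernel(x, y)  # one independent result per element
```

Because each element is independent, a JIT backend can fuse all the intermediate operations into a single pass over memory, which is where most of the speedups over plain NumPy come from.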
$ taskset -c 23 python run.py benchmarks/equation_of_state/
benchmarks.equation_of_state
============================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 pytorch 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.003 6.188
4,096 jax 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.004 4.581
4,096 numba 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.001 2.808
4,096 aesara 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.004 2.517
4,096 tensorflow 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.004 2.507
4,096 numpy 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.005 1.000
16,384 pytorch 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.002 4.947
16,384 jax 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.002 4.193
16,384 tensorflow 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.006 3.598
16,384 numba 10,000 0.003 0.000 0.003 0.003 0.003 0.003 0.006 2.861
16,384 aesara 1,000 0.003 0.000 0.003 0.003 0.003 0.003 0.003 2.734
16,384 numpy 1,000 0.008 0.000 0.008 0.008 0.008 0.008 0.008 1.000
65,536 pytorch 1,000 0.007 0.000 0.007 0.007 0.007 0.007 0.007 4.655
65,536 tensorflow 1,000 0.007 0.000 0.007 0.007 0.007 0.007 0.007 4.405
65,536 jax 1,000 0.008 0.000 0.008 0.008 0.008 0.008 0.012 3.997
65,536 numba 1,000 0.011 0.000 0.011 0.011 0.011 0.011 0.011 2.954
65,536 aesara 1,000 0.011 0.000 0.011 0.011 0.011 0.011 0.011 2.899
65,536 numpy 100 0.032 0.000 0.032 0.032 0.032 0.032 0.032 1.000
262,144 pytorch 1,000 0.023 0.000 0.023 0.023 0.023 0.023 0.024 5.570
262,144 tensorflow 1,000 0.024 0.000 0.023 0.024 0.024 0.024 0.024 5.501
262,144 jax 100 0.024 0.000 0.024 0.024 0.024 0.024 0.025 5.423
262,144 numba 100 0.039 0.000 0.039 0.039 0.039 0.039 0.039 3.319
262,144 aesara 100 0.040 0.000 0.040 0.040 0.040 0.040 0.040 3.280
262,144 numpy 100 0.130 0.001 0.129 0.130 0.130 0.131 0.131 1.000
1,048,576 pytorch 100 0.092 0.000 0.092 0.092 0.092 0.092 0.092 7.118
1,048,576 jax 100 0.103 0.000 0.103 0.103 0.103 0.104 0.104 6.308
1,048,576 tensorflow 100 0.105 0.000 0.105 0.105 0.105 0.105 0.105 6.210
1,048,576 numba 100 0.161 0.000 0.160 0.161 0.161 0.161 0.161 4.060
1,048,576 aesara 100 0.164 0.000 0.163 0.164 0.164 0.164 0.164 3.987
1,048,576 numpy 10 0.653 0.002 0.647 0.653 0.653 0.654 0.654 1.000
4,194,304 pytorch 10 0.388 0.000 0.388 0.388 0.388 0.388 0.388 9.440
4,194,304 jax 10 0.397 0.000 0.397 0.397 0.397 0.398 0.398 9.220
4,194,304 tensorflow 10 0.418 0.000 0.418 0.418 0.418 0.418 0.419 8.755
4,194,304 numba 10 0.630 0.001 0.629 0.630 0.630 0.631 0.631 5.812
4,194,304 aesara 10 0.647 0.000 0.647 0.647 0.647 0.647 0.648 5.659
4,194,304 numpy 10 3.662 0.002 3.659 3.661 3.663 3.664 3.665 1.000
(time in wall seconds, less is better)
$ taskset -c 23 python run.py benchmarks/equation_of_state/ -s 16777216
benchmarks.equation_of_state
============================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
16,777,216 pytorch 10 1.380 0.009 1.372 1.375 1.376 1.383 1.402 10.413
16,777,216 tensorflow 10 1.665 0.001 1.664 1.665 1.665 1.665 1.667 8.628
16,777,216 jax 10 1.737 0.001 1.736 1.737 1.737 1.738 1.740 8.270
16,777,216 numba 10 2.436 0.002 2.432 2.436 2.436 2.437 2.438 5.899
16,777,216 aesara 10 2.549 0.001 2.548 2.549 2.549 2.549 2.553 5.636
16,777,216 numpy 10 14.369 0.004 14.362 14.366 14.368 14.371 14.377 1.000
(time in wall seconds, less is better)
$ for backend in cupy jax pytorch tensorflow; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/equation_of_state/ --gpu -b $backend -b numpy; done
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 10,000 0.002 0.001 0.002 0.002 0.002 0.002 0.016 1.000
4,096 cupy 1,000 0.007 0.001 0.006 0.007 0.007 0.007 0.020 0.273
16,384 cupy 1,000 0.007 0.001 0.007 0.007 0.007 0.007 0.020 1.190
16,384 numpy 1,000 0.008 0.002 0.007 0.008 0.008 0.008 0.022 1.000
65,536 cupy 1,000 0.007 0.001 0.007 0.007 0.007 0.007 0.020 6.930
65,536 numpy 100 0.047 0.004 0.032 0.043 0.050 0.050 0.052 1.000
262,144 cupy 1,000 0.007 0.001 0.007 0.007 0.007 0.007 0.021 29.436
262,144 numpy 100 0.200 0.008 0.125 0.198 0.203 0.203 0.205 1.000
1,048,576 cupy 100 0.016 0.000 0.016 0.016 0.016 0.016 0.017 49.831
1,048,576 numpy 10 0.811 0.001 0.810 0.810 0.811 0.811 0.812 1.000
4,194,304 cupy 100 0.061 0.000 0.061 0.061 0.061 0.061 0.062 60.944
4,194,304 numpy 10 3.694 0.002 3.691 3.693 3.694 3.695 3.698 1.000
(time in wall seconds, less is better)
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 10,000 0.000 0.001 0.000 0.000 0.000 0.000 0.015 14.306
4,096 numpy 10,000 0.002 0.001 0.002 0.002 0.002 0.002 0.017 1.000
16,384 jax 10,000 0.000 0.001 0.000 0.000 0.000 0.000 0.015 67.323
16,384 numpy 1,000 0.009 0.002 0.007 0.008 0.008 0.008 0.022 1.000
65,536 jax 10,000 0.000 0.001 0.000 0.000 0.000 0.000 0.016 380.146
65,536 numpy 100 0.051 0.002 0.043 0.051 0.051 0.051 0.056 1.000
262,144 jax 1,000 0.000 0.001 0.000 0.000 0.000 0.000 0.015 1276.789
262,144 numpy 100 0.205 0.001 0.204 0.204 0.205 0.205 0.208 1.000
1,048,576 jax 100 0.000 0.000 0.000 0.000 0.000 0.000 0.002 2165.099
1,048,576 numpy 10 0.849 0.001 0.848 0.848 0.849 0.850 0.851 1.000
4,194,304 jax 100 0.001 0.000 0.001 0.001 0.001 0.001 0.001 3323.075
4,194,304 numpy 10 3.763 0.002 3.760 3.761 3.764 3.765 3.765 1.000
(time in wall seconds, less is better)
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 pytorch 10,000 0.000 0.001 0.000 0.000 0.000 0.000 0.013 15.127
4,096 numpy 10,000 0.002 0.001 0.002 0.002 0.002 0.002 0.015 1.000
16,384 pytorch 10,000 0.000 0.001 0.000 0.000 0.000 0.000 0.013 64.685
16,384 numpy 1,000 0.008 0.002 0.007 0.008 0.008 0.008 0.020 1.000
65,536 pytorch 1,000 0.000 0.001 0.000 0.000 0.000 0.000 0.013 359.766
65,536 numpy 100 0.047 0.002 0.045 0.046 0.046 0.047 0.053 1.000
262,144 pytorch 1,000 0.000 0.001 0.000 0.000 0.000 0.000 0.012 1067.722
262,144 numpy 100 0.200 0.002 0.187 0.199 0.200 0.201 0.206 1.000
1,048,576 pytorch 1,000 0.000 0.001 0.000 0.000 0.000 0.000 0.012 1673.047
1,048,576 numpy 10 0.772 0.001 0.771 0.771 0.772 0.772 0.774 1.000
4,194,304 pytorch 100 0.001 0.000 0.001 0.001 0.001 0.001 0.001 3173.256
4,194,304 numpy 10 3.690 0.002 3.687 3.688 3.690 3.691 3.693 1.000
(time in wall seconds, less is better)
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 tensorflow 10,000 0.000 0.001 0.000 0.000 0.000 0.000 0.015 4.807
4,096 numpy 10,000 0.002 0.001 0.002 0.002 0.002 0.002 0.017 1.000
16,384 tensorflow 10,000 0.000 0.001 0.000 0.000 0.000 0.000 0.016 22.519
16,384 numpy 1,000 0.008 0.001 0.007 0.008 0.008 0.008 0.022 1.000
65,536 tensorflow 10,000 0.000 0.001 0.000 0.000 0.000 0.000 0.016 113.232
65,536 numpy 100 0.042 0.005 0.035 0.039 0.039 0.040 0.052 1.000
262,144 tensorflow 1,000 0.000 0.000 0.000 0.000 0.000 0.000 0.012 555.981
262,144 numpy 100 0.199 0.003 0.196 0.197 0.198 0.198 0.207 1.000
1,048,576 tensorflow 1,000 0.001 0.001 0.001 0.001 0.001 0.001 0.015 1160.568
1,048,576 numpy 10 0.791 0.001 0.790 0.791 0.791 0.792 0.793 1.000
4,194,304 tensorflow 100 0.001 0.001 0.001 0.001 0.001 0.001 0.009 5053.383
4,194,304 numpy 10 3.752 0.002 3.749 3.750 3.751 3.754 3.756 1.000
(time in wall seconds, less is better)
A more balanced routine with many data dependencies (stencil operations) and tensor shapes of up to 5 dimensions. This is the most expensive part of Veros, so this is the benchmark that interests me the most.
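As a sketch of what a stencil operation looks like (a hypothetical 5-point Laplacian, not code from the actual benchmark): each interior point is computed from its neighbors, so the access pattern is no longer purely elementwise.

```python
import numpy as np

def laplacian_2d(f):
    # Hypothetical 5-point stencil: each interior point reads its four
    # neighbors, creating the data dependencies mentioned above.
    out = np.zeros_like(f)
    out[1:-1, 1:-1] = (
        f[2:, 1:-1] + f[:-2, 1:-1]      # neighbors along axis 0
        + f[1:-1, 2:] + f[1:-1, :-2]    # neighbors along axis 1
        - 4.0 * f[1:-1, 1:-1]
    )
    return out

field = np.arange(25.0).reshape(5, 5)
result = laplacian_2d(field)
```

Neighbor accesses like these make memory traffic less regular and give the compiler fewer fusion opportunities, which is one plausible reason the speedups here are far more modest than in the elementwise benchmark above.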
$ taskset -c 23 python run.py benchmarks/isoneutral_mixing/
benchmarks.isoneutral_mixing
============================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 1,000 0.001 0.000 0.001 0.001 0.001 0.001 0.003 3.617
4,096 numba 1,000 0.001 0.000 0.001 0.001 0.001 0.001 0.004 3.074
4,096 aesara 1,000 0.003 0.000 0.003 0.003 0.003 0.003 0.005 1.566
4,096 pytorch 1,000 0.004 0.000 0.004 0.004 0.004 0.004 0.006 1.061
4,096 numpy 1,000 0.004 0.000 0.004 0.004 0.004 0.004 0.007 1.000
16,384 jax 1,000 0.006 0.000 0.006 0.006 0.006 0.006 0.009 2.759
16,384 numba 1,000 0.007 0.000 0.007 0.007 0.007 0.007 0.009 2.348
16,384 pytorch 1,000 0.011 0.000 0.010 0.010 0.011 0.011 0.014 1.528
16,384 aesara 1,000 0.011 0.000 0.011 0.011 0.011 0.011 0.014 1.451
16,384 numpy 1,000 0.016 0.000 0.016 0.016 0.016 0.016 0.019 1.000
65,536 jax 100 0.028 0.000 0.028 0.028 0.028 0.028 0.030 2.238
65,536 numba 100 0.030 0.000 0.030 0.030 0.030 0.030 0.031 2.081
65,536 pytorch 100 0.037 0.000 0.037 0.037 0.037 0.037 0.038 1.702
65,536 aesara 100 0.047 0.000 0.047 0.047 0.047 0.047 0.049 1.352
65,536 numpy 100 0.063 0.000 0.063 0.063 0.063 0.063 0.066 1.000
262,144 numba 100 0.116 0.000 0.115 0.116 0.116 0.116 0.116 2.157
262,144 jax 100 0.122 0.000 0.122 0.122 0.122 0.122 0.124 2.047
262,144 pytorch 100 0.147 0.000 0.147 0.147 0.147 0.148 0.149 1.693
262,144 aesara 10 0.179 0.000 0.179 0.179 0.179 0.179 0.180 1.393
262,144 numpy 10 0.250 0.000 0.249 0.250 0.250 0.250 0.250 1.000
1,048,576 numba 10 0.516 0.004 0.512 0.512 0.519 0.519 0.520 2.221
1,048,576 jax 10 0.623 0.004 0.616 0.620 0.626 0.626 0.627 1.840
1,048,576 pytorch 10 0.751 0.002 0.747 0.752 0.752 0.752 0.752 1.527
1,048,576 aesara 10 0.851 0.001 0.850 0.850 0.851 0.851 0.852 1.348
1,048,576 numpy 10 1.147 0.002 1.142 1.147 1.147 1.148 1.148 1.000
4,194,304 numba 10 2.247 0.003 2.243 2.244 2.246 2.250 2.252 2.282
4,194,304 jax 10 2.569 0.003 2.563 2.569 2.570 2.571 2.573 1.995
4,194,304 aesara 10 3.773 0.002 3.769 3.772 3.774 3.776 3.776 1.359
4,194,304 pytorch 10 3.797 0.022 3.751 3.784 3.797 3.815 3.826 1.350
4,194,304 numpy 10 5.126 0.003 5.119 5.124 5.128 5.128 5.131 1.000
(time in wall seconds, less is better)
$ taskset -c 23 python run.py benchmarks/isoneutral_mixing/ -s 16777216
benchmarks.isoneutral_mixing
============================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
16,777,216 numba 10 9.239 0.098 9.042 9.178 9.295 9.301 9.327 2.614
16,777,216 jax 10 10.134 0.006 10.125 10.131 10.134 10.137 10.144 2.383
16,777,216 aesara 10 15.387 0.046 15.323 15.343 15.389 15.433 15.450 1.569
16,777,216 pytorch 10 17.916 0.024 17.856 17.910 17.925 17.931 17.939 1.348
16,777,216 numpy 10 24.148 0.018 24.117 24.137 24.147 24.161 24.179 1.000
(time in wall seconds, less is better)
$ for backend in cupy jax pytorch; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/isoneutral_mixing/ --gpu -b $backend -b numpy; done
benchmarks.isoneutral_mixing
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 1,000 0.004 0.000 0.004 0.004 0.004 0.004 0.008 1.000
4,096 cupy 1,000 0.011 0.000 0.010 0.010 0.011 0.011 0.014 0.401
16,384 cupy 1,000 0.011 0.000 0.010 0.011 0.011 0.011 0.014 1.519
16,384 numpy 1,000 0.016 0.000 0.016 0.016 0.016 0.016 0.020 1.000
65,536 cupy 100 0.011 0.000 0.011 0.011 0.011 0.011 0.011 5.851
65,536 numpy 100 0.063 0.001 0.063 0.063 0.063 0.063 0.071 1.000
262,144 cupy 100 0.011 0.000 0.011 0.011 0.011 0.011 0.013 24.889
262,144 numpy 10 0.273 0.005 0.257 0.274 0.274 0.275 0.276 1.000
1,048,576 cupy 10 0.021 0.000 0.021 0.021 0.021 0.022 0.022 56.671
1,048,576 numpy 10 1.211 0.004 1.208 1.209 1.209 1.211 1.222 1.000
4,194,304 cupy 10 0.080 0.001 0.079 0.079 0.079 0.081 0.082 64.295
4,194,304 numpy 10 5.133 0.003 5.127 5.133 5.134 5.135 5.138 1.000
(time in wall seconds, less is better)
benchmarks.isoneutral_mixing
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 1,000 0.001 0.000 0.001 0.001 0.001 0.001 0.005 8.233
4,096 numpy 1,000 0.005 0.000 0.005 0.005 0.005 0.005 0.009 1.000
16,384 jax 1,000 0.001 0.000 0.001 0.001 0.001 0.001 0.003 25.151
16,384 numpy 1,000 0.017 0.000 0.016 0.016 0.017 0.017 0.019 1.000
65,536 jax 100 0.001 0.000 0.001 0.001 0.001 0.001 0.001 56.285
65,536 numpy 100 0.065 0.001 0.063 0.065 0.065 0.065 0.068 1.000
262,144 jax 100 0.004 0.000 0.004 0.004 0.004 0.004 0.007 67.464
262,144 numpy 10 0.270 0.009 0.258 0.262 0.272 0.273 0.283 1.000
1,048,576 jax 10 0.015 0.000 0.015 0.015 0.015 0.015 0.015 81.204
1,048,576 numpy 10 1.241 0.001 1.239 1.240 1.241 1.242 1.242 1.000
4,194,304 jax 10 0.057 0.000 0.057 0.057 0.057 0.057 0.057 91.099
4,194,304 numpy 10 5.222 0.020 5.209 5.215 5.216 5.219 5.281 1.000
(time in wall seconds, less is better)
benchmarks.isoneutral_mixing
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 1,000 0.004 0.000 0.004 0.004 0.004 0.004 0.007 1.000
4,096 pytorch 1,000 0.005 0.000 0.005 0.005 0.005 0.005 0.008 0.827
16,384 pytorch 1,000 0.005 0.000 0.005 0.005 0.005 0.005 0.008 3.072
16,384 numpy 1,000 0.016 0.000 0.016 0.016 0.016 0.016 0.020 1.000
65,536 pytorch 100 0.006 0.000 0.006 0.006 0.006 0.006 0.006 11.134
65,536 numpy 100 0.063 0.001 0.063 0.063 0.063 0.063 0.066 1.000
262,144 pytorch 100 0.006 0.000 0.006 0.006 0.006 0.006 0.006 45.709
262,144 numpy 10 0.271 0.009 0.250 0.267 0.274 0.276 0.285 1.000
1,048,576 pytorch 10 0.016 0.001 0.016 0.016 0.016 0.016 0.018 74.789
1,048,576 numpy 10 1.208 0.001 1.207 1.207 1.208 1.209 1.209 1.000
4,194,304 pytorch 10 0.056 0.000 0.056 0.056 0.056 0.057 0.057 91.363
4,194,304 numpy 10 5.139 0.002 5.135 5.136 5.139 5.141 5.141 1.000
(time in wall seconds, less is better)
This routine consists of some stencil operations and some linear algebra (a tridiagonal matrix solver), the latter of which cannot be vectorized.
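To see why a tridiagonal solve resists vectorization, here is a minimal sketch of the Thomas algorithm (my illustration, not the benchmark's actual solver): the forward sweep carries a loop dependency, so each step must wait for the previous one.

```python
import numpy as np

def thomas_solve(a, b, c, d):
    # Hypothetical Thomas-algorithm sketch for a tridiagonal system.
    # a: sub-diagonal (a[0] unused), b: diagonal,
    # c: super-diagonal (c[-1] unused), d: right-hand side.
    n = len(d)
    cp = np.empty(n)
    dp = np.empty(n)
    cp[0] = c[0] / b[0]
    dp[0] = d[0] / b[0]
    for i in range(1, n):  # loop-carried dependency: step i needs step i-1
        denom = b[i] - a[i] * cp[i - 1]
        cp[i] = c[i] / denom
        dp[i] = (d[i] - a[i] * dp[i - 1]) / denom
    x = np.empty(n)
    x[-1] = dp[-1]
    for i in range(n - 2, -1, -1):  # backward substitution, also sequential
        x[i] = dp[i] - cp[i] * x[i + 1]
    return x
```

The sequential sweeps favor backends that compile the loop itself (numba, jax), while pure NumPy pays Python-level overhead per iteration unless the solve is batched across many independent systems.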
$ taskset -c 23 python run.py benchmarks/turbulent_kinetic_energy/
benchmarks.turbulent_kinetic_energy
===================================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 1,000 0.000 0.000 0.000 0.000 0.000 0.000 0.001 5.826
4,096 numba 1,000 0.001 0.000 0.001 0.001 0.001 0.001 0.002 1.838
4,096 pytorch 1,000 0.002 0.000 0.002 0.002 0.002 0.002 0.003 1.221
4,096 numpy 1,000 0.002 0.000 0.002 0.002 0.002 0.002 0.003 1.000
16,384 jax 1,000 0.002 0.000 0.002 0.002 0.002 0.002 0.003 3.914
16,384 pytorch 1,000 0.004 0.000 0.004 0.004 0.004 0.004 0.005 1.862
16,384 numba 1,000 0.004 0.000 0.004 0.004 0.004 0.004 0.005 1.802
16,384 numpy 1,000 0.008 0.000 0.008 0.008 0.008 0.008 0.011 1.000
65,536 jax 100 0.009 0.000 0.009 0.009 0.009 0.009 0.010 3.263
65,536 pytorch 100 0.013 0.000 0.013 0.013 0.013 0.013 0.013 2.209
65,536 numba 100 0.014 0.000 0.014 0.014 0.014 0.014 0.014 2.006
65,536 numpy 100 0.029 0.000 0.028 0.028 0.028 0.029 0.029 1.000
262,144 jax 100 0.042 0.000 0.042 0.042 0.042 0.042 0.043 2.569
262,144 numba 100 0.047 0.000 0.047 0.047 0.047 0.048 0.048 2.295
262,144 pytorch 100 0.050 0.000 0.050 0.050 0.050 0.050 0.051 2.171
262,144 numpy 10 0.109 0.000 0.108 0.109 0.109 0.109 0.110 1.000
1,048,576 numba 10 0.187 0.000 0.187 0.187 0.187 0.187 0.188 2.711
1,048,576 jax 10 0.237 0.000 0.237 0.237 0.237 0.237 0.238 2.140
1,048,576 pytorch 10 0.276 0.000 0.275 0.276 0.276 0.276 0.276 1.839
1,048,576 numpy 10 0.507 0.001 0.506 0.507 0.507 0.508 0.508 1.000
4,194,304 numba 10 0.689 0.002 0.686 0.687 0.690 0.691 0.693 3.022
4,194,304 jax 10 1.043 0.001 1.042 1.043 1.043 1.044 1.044 1.997
4,194,304 pytorch 10 1.314 0.003 1.310 1.312 1.314 1.317 1.318 1.585
4,194,304 numpy 10 2.084 0.003 2.079 2.080 2.085 2.086 2.088 1.000
(time in wall seconds, less is better)
$ taskset -c 23 python run.py benchmarks/turbulent_kinetic_energy/ -s 16777216
benchmarks.turbulent_kinetic_energy
===================================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
16,777,216 numba 10 2.997 0.005 2.991 2.994 2.996 2.999 3.007 3.616
16,777,216 jax 10 4.168 0.003 4.164 4.165 4.168 4.170 4.174 2.600
16,777,216 pytorch 10 6.270 0.009 6.249 6.266 6.270 6.277 6.282 1.729
16,777,216 numpy 10 10.839 0.011 10.823 10.829 10.836 10.850 10.853 1.000
(time in wall seconds, less is better)
$ for backend in jax pytorch; do CUDA_VISIBLE_DEVICES="0" python run.py benchmarks/turbulent_kinetic_energy/ --gpu -b $backend -b numpy; done
benchmarks.turbulent_kinetic_energy
===================================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 1,000 0.001 0.000 0.000 0.001 0.001 0.001 0.008 4.261
4,096 numpy 1,000 0.003 0.000 0.002 0.002 0.003 0.003 0.010 1.000
16,384 jax 1,000 0.001 0.001 0.001 0.001 0.001 0.001 0.008 11.252
16,384 numpy 1,000 0.008 0.000 0.008 0.008 0.008 0.008 0.013 1.000
65,536 jax 100 0.001 0.001 0.001 0.001 0.001 0.001 0.007 26.405
65,536 numpy 100 0.030 0.002 0.029 0.029 0.029 0.029 0.037 1.000
262,144 jax 100 0.003 0.001 0.003 0.003 0.003 0.003 0.006 46.720
262,144 numpy 10 0.135 0.007 0.115 0.134 0.136 0.137 0.143 1.000
1,048,576 jax 10 0.016 0.000 0.015 0.016 0.016 0.016 0.016 36.351
1,048,576 numpy 10 0.579 0.008 0.567 0.571 0.584 0.585 0.588 1.000
4,194,304 jax 10 0.039 0.000 0.039 0.039 0.039 0.039 0.039 55.896
4,194,304 numpy 10 2.190 0.032 2.149 2.152 2.212 2.218 2.220 1.000
(time in wall seconds, less is better)
benchmarks.turbulent_kinetic_energy
===================================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 1,000 0.002 0.000 0.002 0.002 0.002 0.002 0.005 1.000
4,096 pytorch 1,000 0.003 0.000 0.003 0.003 0.003 0.003 0.004 0.834
16,384 pytorch 1,000 0.003 0.000 0.003 0.003 0.003 0.003 0.005 2.455
16,384 numpy 1,000 0.008 0.000 0.008 0.008 0.008 0.008 0.008 1.000
65,536 pytorch 100 0.004 0.000 0.004 0.004 0.004 0.004 0.004 7.323
65,536 numpy 100 0.029 0.000 0.029 0.029 0.029 0.029 0.030 1.000
262,144 pytorch 100 0.005 0.000 0.005 0.005 0.005 0.005 0.005 23.533
262,144 numpy 10 0.111 0.000 0.110 0.111 0.111 0.111 0.111 1.000
1,048,576 pytorch 10 0.008 0.000 0.008 0.008 0.008 0.008 0.008 72.466
1,048,576 numpy 10 0.573 0.003 0.567 0.571 0.574 0.574 0.576 1.000
4,194,304 pytorch 10 0.029 0.000 0.029 0.029 0.030 0.030 0.030 73.957
4,194,304 numpy 10 2.175 0.002 2.172 2.174 2.175 2.176 2.177 1.000
(time in wall seconds, less is better)