
Memory usage in vectorised Jacobian computation #454

Closed
matt-graham opened this issue Dec 5, 2018 · 2 comments

Comments

matt-graham commented Dec 5, 2018

I'm interested in using Autograd to compute Jacobians. As the inbuilt jacobian function can be slow in some cases due to (I think) the interpreter overhead of iteratively computing a vector-Jacobian product for each output dimension, I've been using a slightly hacky vectorised alternative.

Specifically, I make sure all the functions I need to compute Jacobians for broadcast over the leading dimensions of their input(s). For example, suppose a function func in the base case takes an array u with shape shape_u and returns an array y = func(u) with shape shape_y. If we instead pass in an array u with shape (n,) + shape_u, we expect the returned y = func(u) to have shape (n,) + shape_y, with all([np.allclose(y[i], func(u[i])) for i in range(n)]) == True. For a function of this type, the jacobian_and_value function below (adapted from the functions in the autograd.differential_operators module) computes the Jacobian by tiling size_output copies of the input along a new leading dimension, where size_output is the size of the output array, and then computing a single 'vector'-Jacobian product using the identity matrix (or its generalisation for output arrays with more than one dimension in the base case) as the 'vector'.

from autograd.wrap_util import unary_to_nary
from autograd.core import make_vjp as _make_vjp
from autograd.extend import vspace
import autograd.numpy as np

@unary_to_nary
def jacobian_and_value(fun, x):
    # Forward pass on the original input to get the output and its vector space.
    val = fun(x)
    v_vspace = vspace(val)
    x_vspace = vspace(x)
    # Tile the input along a new leading dimension, one copy per output element.
    x_rep = np.tile(x, (v_vspace.size,) + (1,) * x_vspace.ndim)
    # Build a VJP at the replicated input; fun is assumed to broadcast over the
    # leading dimension, so fun(x_rep) has shape (v_vspace.size,) + v_vspace.shape.
    vjp_rep, _ = _make_vjp(fun, x_rep)
    jacobian_shape = v_vspace.shape + x_vspace.shape
    # Stack the standard basis of the output space to form the 'vector'
    # (an identity matrix when the output is one-dimensional in the base case).
    basis_vectors = np.array([b for b in v_vspace.standard_basis()])
    # A single batched 'vector'-Jacobian product gives all rows of the Jacobian.
    jacobian = vjp_rep(basis_vectors)
    return np.reshape(jacobian, jacobian_shape), val
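
As a minimal illustration, here is a made-up example function with the broadcasting property described above (not one of my actual models):

# Hypothetical example function: broadcasts over leading dimensions of u,
# mapping trailing shape (2,) to trailing shape (3,).
def func(u):
    return np.concatenate([u ** 2, np.sum(u, axis=-1, keepdims=True)], axis=-1)

u = np.array([1.5, -0.5])
jacobian, val = jacobian_and_value(func)(u)
print(val.shape)       # (3,)
print(jacobian.shape)  # (3, 2), entry [i, j] = d func(u)[i] / d u[j]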

In terms of computation time, this can give quite significant savings over the inbuilt jacobian function in some cases, due to pushing more of the work into NumPy array functions. However, it has a couple of downsides:

  • The memory usage is now of order [size of output] times higher than with the inbuilt jacobian function, as all the intermediate values computed in the forward pass for each of the replicated inputs now need to be stored for the backward pass.
  • Relatedly, we now compute the forward pass for multiple identical inputs, whereas we should really only need to compute it once. As this only adds a small constant-factor overhead to the overall computation time for the Jacobian, it is less of a problem.

I have found the increased memory usage to be particularly problematic when the function being differentiated involves a loop with many iterations (e.g. the timestepping loop in the numerical integration of an ODE / PDE system), as even the memory usage associated with tracing a single forward pass can be quite significant in this case.
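
To give a concrete (made-up) sketch of the kind of function I mean, assuming explicit Euler timestepping of an elementwise ODE so that everything broadcasts over leading dimensions of u0:

def integrate(u0, n_step=1000, dt=1e-3):
    # Explicit Euler timestepping of du/dt = -u**3; all operations are
    # elementwise, so the loop broadcasts over leading dimensions of u0.
    u = u0
    for _ in range(n_step):
        u = u - dt * u ** 3
    return u

Tracing the forward pass of a function like this stores all n_step intermediate states, and tiling the input multiplies that memory cost by the output size.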

I originally played around with trying to use the VJP function returned by make_vjp when evaluating the function on the original, non-replicated input to do something similar; however (unsurprisingly) this didn't work, as even if the primitive calls in the original forward function broadcast as required along the leading dimensions, the corresponding primitive VJP calls in the backward pass do not necessarily do so. I also tried adding an extra leading dimension of length one to the input passed to make_vjp, but this didn't help (again not surprising). The main culprits I've encountered so far for operations that cause problems when trying to broadcast in the backward pass are array slicing (and the corresponding untake calls in the backward pass) and calls to numpy.reshape.

It seems that what I want to do is exactly what would be achieved by the tensor-Jacobian products proposed in PR #280 by @mattjj, however I am not sure whether that is still being actively worked on?

Otherwise, I was wondering if anyone has any other ideas for ways to achieve something like this without the memory overhead?

mattjj commented Dec 10, 2018

Thanks for this detailed explanation! Those are interesting findings with your vectorized Jacobian implementation. Indeed vectorization seems really important here.

This isn't a full response to your issue, but I wanted to offer a quick note in case it's helpful: in JAX we think we cracked the "tensor-Jacobian products" problem in a more general way via the vmap transformation. We used it to implement both forward- and reverse-mode Jacobian functions (jacfwd and jacrev), which indeed only linearize once. Plus you can use JAX's jit transformation to end-to-end compile these functions and even run them on accelerator hardware like a GPU (if you have one). The jit compilation uses XLA for whole-program optimization, so even if you're just using the CPU backend, for many computations it should be much faster due to XLA's optimizations and removing all interpreter overhead (no Python interpreter, but moreover not even any graph interpreter at all).
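
As a rough sketch of that construction (the function f here is just a made-up example, and this uses the public jax.vjp, jax.vmap and jax.jacrev APIs rather than the internal jacrev implementation):

import jax
import jax.numpy as jnp

def f(u):
    # Made-up example function mapping R^3 -> R^2.
    return jnp.array([jnp.sin(u[0]) * u[1], u[0] ** 2 + u[2]])

def jacobian_rev(fun, x):
    y, vjp_fun = jax.vjp(fun, x)            # linearize once
    basis = jnp.eye(y.size, dtype=y.dtype)  # one-hot cotangent per output element
    (jac,) = jax.vmap(vjp_fun)(basis)       # vectorise the VJP over all cotangents
    return jac

x = jnp.array([0.1, 0.2, 0.3])
assert jnp.allclose(jacobian_rev(f, x), jax.jacrev(f)(x))

This composes with jit as well, e.g. jax.jit(jax.jacrev(f))(x).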

JAX doesn't have as much NumPy coverage as Autograd yet, and so it might not cover your use case, but we're working hard on it and expect it to help with use cases like this one.

There are some other tricks we might be able to deploy to reduce memory usage, depending on your specific computation, including a recent trick developed by some Julia folks: https://arxiv.org/pdf/1810.08297.pdf.

matt-graham (Author) commented

Thanks @mattjj for your comments and suggestions!

JAX looks really interesting: being able to write the model functions without explicitly worrying about broadcasting / vectorisation would be ideal, as would being able to compile functions to avoid interpreter overhead. I had a go at using the jacfwd and jacrev functions on some simple instances of the sort of models I've been looking at; it seems there are a few primitives lacking batched versions, which prevents me from using JAX for my use cases at the moment, but otherwise it looks like it would be a great fit. I will have a look at the JAX batching implementation to see if I can submit any PRs there for the primitives my use cases need, or at least create issues to document what is not currently covered.

I will close the issue here, as it seems that going forward JAX will probably be able to resolve the issues I have, and for now my hacky workaround will suffice!
