diff --git a/autograd/numpy/__init__.py b/autograd/numpy/__init__.py
index 0858819f..fd929697 100644
--- a/autograd/numpy/__init__.py
+++ b/autograd/numpy/__init__.py
@@ -4,6 +4,7 @@
 from . import numpy_vspaces
 from . import numpy_vjps
 from . import numpy_jvps
+from . import numpy_tjps
 from . import linalg
 from . import fft
 from . import random
diff --git a/autograd/numpy/numpy_tjps.py b/autograd/numpy/numpy_tjps.py
new file mode 100644
index 00000000..e0b6f80d
--- /dev/null
+++ b/autograd/numpy/numpy_tjps.py
@@ -0,0 +1,134 @@
+from __future__ import absolute_import
+import numpy as onp
+from functools import partial
+from ..util import func  # TODO(mattjj): should this import use autograd.util, not ..util?
+from autograd.tracer import primitive, getval
+from autograd.vspace import vspace
+from autograd.core import SparseObject
+from autograd.tjp import deftjp, vjps_are_tjps
+from . import numpy_wrapper as anp
+from .numpy_boxes import ArrayBox
+from .numpy_vjps import balanced_eq, replace_zero  # used by the max/min and power TJPs below; without this the module raises NameError
+
+# ----- Binary ufuncs -----
+
+# The only difference here is we have to use a modified unbroadcast function,
+# which handles leading dimensions (if they exist). Otherwise, the expressions
+# used in the VJPs already broadcast along leading dimensions of g.
+
+deftjp(anp.add, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g))
+deftjp(anp.add, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g), argnum=1)
+deftjp(anp.multiply, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, y * g))
+deftjp(anp.multiply, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, x * g), argnum=1)
+deftjp(anp.subtract, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g))
+deftjp(anp.subtract, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, -g), argnum=1)
+deftjp(anp.divide, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g / y))
+deftjp(anp.divide, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, - g * x / y**2), argnum=1)
+deftjp(anp.maximum, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(x, ans, y)))  # balanced_eq: numpy_vjps helper — presumably splits the gradient at ties; verify there
+deftjp(anp.maximum, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(y, ans, x)), argnum=1)
+deftjp(anp.minimum, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(x, ans, y)))
+deftjp(anp.minimum, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(y, ans, x)), argnum=1)
+deftjp(anp.fmax, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(x, ans, y)))
+deftjp(anp.fmax, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(y, ans, x)), argnum=1)
+deftjp(anp.fmin, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(x, ans, y)))
+deftjp(anp.fmin, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(y, ans, x)), argnum=1)
+deftjp(anp.logaddexp, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * anp.exp(x-ans)))  # d/dx log(e^x + e^y) = e^(x - ans)
+deftjp(anp.logaddexp, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * anp.exp(y-ans)), argnum=1)
+deftjp(anp.logaddexp2, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * 2**(x-ans)))  # d/dx log2(2^x + 2^y) = 2^(x - ans)
+deftjp(anp.logaddexp2, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * 2**(y-ans)), argnum=1)
+deftjp(anp.true_divide, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g / y))
+deftjp(anp.true_divide, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, - g * x / y**2), argnum=1)
+deftjp(anp.mod, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g))
+deftjp(anp.remainder, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g))
+deftjp(anp.mod, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, -g * anp.floor(x/y)), argnum=1)  # d/dy (x mod y) = -floor(x/y)
+deftjp(anp.remainder, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, -g * anp.floor(x/y)), argnum=1)
+deftjp(anp.power,
+       lambda ans, vs, out_vs, x, y : lambda g:
+       unbroadcast(vs, out_vs, g * y * x ** anp.where(y, y - 1, 1.)))  # where(y, y-1, 1.) avoids evaluating x**(-1) when y == 0
+deftjp(anp.power,
+       lambda ans, vs, out_vs, x, y : lambda g:
+       unbroadcast(vs, out_vs, g * anp.log(replace_zero(x, 1.)) * x ** y), argnum=1)  # replace_zero keeps log well-defined at x == 0
+
+# ----- Simple grads -----
+
+# Some VJP implementations already broadcast along leading dimensions of g, so
+# they work as TJP definitions too. We use the vjps_are_tjps function for that.
+
+vjps_are_tjps(anp.absolute)
+vjps_are_tjps(anp.reciprocal)
+vjps_are_tjps(anp.exp)
+vjps_are_tjps(anp.exp2)
+vjps_are_tjps(anp.expm1)
+vjps_are_tjps(anp.log)
+vjps_are_tjps(anp.log2)
+vjps_are_tjps(anp.log10)
+vjps_are_tjps(anp.log1p)
+vjps_are_tjps(anp.sin)
+vjps_are_tjps(anp.cos)
+vjps_are_tjps(anp.tan)
+vjps_are_tjps(anp.arcsin)
+vjps_are_tjps(anp.arccos)
+vjps_are_tjps(anp.arctan)
+vjps_are_tjps(anp.sinh)
+vjps_are_tjps(anp.cosh)
+vjps_are_tjps(anp.tanh)
+vjps_are_tjps(anp.arcsinh)
+vjps_are_tjps(anp.arccosh)
+vjps_are_tjps(anp.arctanh)
+vjps_are_tjps(anp.rad2deg)
+vjps_are_tjps(anp.degrees)
+vjps_are_tjps(anp.deg2rad)
+vjps_are_tjps(anp.radians)
+vjps_are_tjps(anp.square)
+vjps_are_tjps(anp.sqrt)
+vjps_are_tjps(anp.sinc)
+
+vjps_are_tjps(anp.conj)
+vjps_are_tjps(anp.conjugate)
+
+# ----- Trickier grads -----
+
+def tjp_dot_arg0(ans, vs, out_vs, A, B):
+    A_ndim, B_ndim = vs.ndim, anp.ndim(B)
+    if B_ndim == 0 or B_ndim == 1 or A_ndim == 0:
+        contract_num = max(0, B_ndim - (A_ndim != 0))
+        return lambda G: anp.tensordot(G, B, contract_num)
+    else:
+        return lambda G: anp.tensordot(G, anp.swapaxes(B, -1, -2), B_ndim - 1)
+deftjp(anp.dot, tjp_dot_arg0)
+
+def tjp_dot_arg1(ans, vs, out_vs, A, B):
+    A_ndim, B_ndim = anp.ndim(A), vs.ndim
+    needs_transpose = B_ndim > 1 and A_ndim != 0
+    swap = (lambda x: anp.swapaxes(x, -1, -2)) if needs_transpose else (lambda x: x)
+    if A_ndim == 0 or A_ndim == 1 or B_ndim == 0:
+        contract_num = max(0, A_ndim - (B_ndim != 0))
+        return lambda G: swap(anp.tensordot(G, A, contract_num))
+    else:
+        return lambda G: swap(anp.tensordot(
+            G, A, [range(-A_ndim - B_ndim + 2, -B_ndim + 1), range(A_ndim - 1)]))
+deftjp(anp.dot, tjp_dot_arg1, argnum=1)
+
+def tjp_transpose(ans, in_vs, out_vs, x, axes=None):
+    axes = tuple(reversed(range(in_vs.ndim))) if axes is None else tuple(anp.argsort(axes))  # tuple(): argsort returns an ndarray, and tuple + ndarray adds elementwise instead of concatenating
+    return lambda g: anp.transpose(g, tuple(range(anp.ndim(g) - len(axes))) + axes)
+deftjp(anp.transpose, tjp_transpose)
+
+# ----- Utility functions -----
+
+def unbroadcast(vs, out_vs, result):
+    # Sum `result` down to the vspace `vs` of the input, while preserving any
+    # leading dimensions of the Jacobian-carrying cotangent (result may have
+    # extra leading axes beyond out_vs.ndim).
+    result_vs = vspace(result)
+    leading_dims = result_vs.ndim - out_vs.ndim
+    broadcast_idx = leading_dims
+    while anp.ndim(result) > leading_dims + vs.ndim:
+        result = anp.sum(result, axis=broadcast_idx)
+    for axis, size in enumerate(vs.shape):
+        if size == 1:
+            result = anp.sum(result, axis=leading_dims + axis, keepdims=True)
+    if result_vs.iscomplex and not vs.iscomplex:
+        result = anp.real(result)
+    return result
+
+# ----- Extra functions used internally -----
+
+# TODO untake
diff --git a/autograd/numpy/numpy_vspaces.py b/autograd/numpy/numpy_vspaces.py
index 739ed5ca..36618e6e 100644
--- a/autograd/numpy/numpy_vspaces.py
+++ b/autograd/numpy/numpy_vspaces.py
@@ -8,7 +8,7 @@ def __init__(self, value):
         self.dtype = value.dtype
 
     @property
-    def size(self): return np.prod(self.shape)
+    def size(self): return int(np.prod(self.shape))
     @property
     def ndim(self): return len(self.shape)
     def zeros(self): return np.zeros(self.shape, dtype=self.dtype)
@@ -26,6 +26,22 @@ def randn(self):
 
     def _inner_prod(self, x, y):
         return np.dot(np.ravel(x), np.ravel(y))
 
+    def _product(self, other_vspace):
+        return self._contract(other_vspace, ndim=0)
+
+    def _contract(self, other_vspace, ndim=None):
+        ndim = other_vspace.ndim if ndim is None else ndim
+        if not self.shape[self.ndim - ndim:] == other_vspace.shape[:ndim]:  # was -ndim % self.ndim, which breaks for ndim == 0 (full slice; ZeroDivisionError for scalar self)
+            raise ValueError
+
+        result = self.__new__(self.__class__)
+        result.shape = self.shape[:self.ndim - ndim] + other_vspace.shape[ndim:]  # was :-ndim % self.ndim, same ndim == 0 defect
+        result.dtype = np.promote_types(self.dtype, other_vspace.dtype)
+        return result
+
+    def _kronecker_tensor(self):
+        return np.reshape(np.eye(self.size), self.shape + self.shape)
+
 class ComplexArrayVSpace(ArrayVSpace):
     iscomplex = True
diff --git a/autograd/tjp.py b/autograd/tjp.py
new file mode 100644
index 00000000..37fb9a71
--- /dev/null
+++ b/autograd/tjp.py
@@ -0,0 +1,70 @@
+from collections import defaultdict
+from functools import partial  # required by deftjps below; was missing
+from .tracer import trace, primitive, Node, toposort
+from .vspace import vspace
+from .core import add_outgrads, primitive_vjps
+
+def make_tjp(fun, x):
+    # Build a tensor-Jacobian-product function for fun at x.  If the output is
+    # disconnected from the input (end_node is None), the TJP is identically
+    # zero over the appropriately-shaped product space.
+    start_node = TJPNode.new_root(x)
+    end_value, end_node = trace(start_node, fun, x)
+    if end_node is None:
+        in_vs, out_vs = start_node.vspace, vspace(end_value)
+        def tjp(G): return vspace(G)._contract(out_vs)._product(in_vs).zeros()  # fixed: referenced undefined name `end_vs`; the bound name is out_vs
+    else:
+        def tjp(G): return tjp_backward_pass(G, end_node)
+    return tjp, end_value
+
+def tjp_backward_pass(G, end_node):
+    assert_vspace_compatible(G, end_node.vspace)
+    outgrads = {end_node : (G, False)}
+    for node in toposort(end_node):
+        cur_outgrad = outgrads.pop(node)
+        for parent, tjp in node.parents_and_tjps:
+            outgrad = tjp(cur_outgrad[0])
+            assert_vspace_compatible(outgrad, parent.vspace)
+            outgrads[parent] = add_outgrads(vspace(outgrad), outgrads.get(parent), outgrad)
+    return cur_outgrad[0]
+
+class TJPNode(Node):
+    __slots__ = ['vspace', 'parents', 'parents_and_tjps']
+    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
+        self.vspace = vspace(value)
+        self.parents = parents
+        self.parents_and_tjps = [
+            (parent, primitive_tjp(fun, argnum, value, parent.vspace,
+                                   self.vspace, args, kwargs))
+            for argnum, parent in zip(parent_argnums, parents)]
+
+    def initialize_root(self, value):
+        self.vspace = vspace(value)
+        self.parents = []
+        self.parents_and_tjps = []
+
+primitive_tjps = defaultdict(dict)
+
+def primitive_tjp(fun, argnum, ans, in_vs, out_vs, args, kwargs):
+    return primitive_tjps[fun][argnum](ans, in_vs, out_vs, args, kwargs)
+
+def deftjp(fun, tjpmaker, argnum=0):
+    def tjp_fixed_args(ans, vs, gvs, args, kwargs):
+        return tjpmaker(ans, vs, gvs, *args, **kwargs)
+    primitive_tjps[fun][argnum] = tjp_fixed_args
+
+def deftjps(fun, tjpmaker, argnums):
+    for argnum in argnums:
+        deftjp(fun, partial(tjpmaker, argnum), argnum)
+
+def vjps_are_tjps(fun):
+    # Reuse a primitive's registered VJPs verbatim when they already broadcast
+    # over leading dimensions of the cotangent.
+    primitive_tjps[fun] = primitive_vjps[fun]
+
+def assert_vspace_compatible(x, vs):
+    assert vs.ndim == 0 or vspace(x).shape[-vs.ndim:] == vs.shape
+
+# convenience-wrapper stuff
+
+from .util import unary_to_nary
+
+@unary_to_nary
+def jacobian(fun, x):
+    tjp, ans = make_tjp(fun, x)
+    return tjp(vspace(ans)._kronecker_tensor())
diff --git a/tests/test_tjps.py b/tests/test_tjps.py
new file mode 100644
index 00000000..9abcfd97
--- /dev/null
+++ b/tests/test_tjps.py
@@ -0,0 +1,25 @@
+import autograd.numpy as np
+import autograd.numpy.random as npr
+from autograd.tjp import jacobian
+from autograd import jacobian as _jacobian
+
+from itertools import product
+
+
+def allclose(x, y): return x.shape == y.shape and np.allclose(x, y)
+
+def test_dot():
+    npr.seed(0)
+    shapes = [(), (2,), (2, 2), (2, 2, 2)]
+    array_pairs = [(npr.normal(size=s1), npr.normal(size=s2))
+                   for s1, s2 in product(shapes, shapes)]
+    argnums = [0, 1]
+
+    def check(A, B, argnum):
+        res1 = jacobian(np.dot, argnum)(A, B)
+        res2 = _jacobian(np.dot, argnum)(A, B)
+        assert allclose(res1, res2)
+
+    for A, B in array_pairs:
+        for argnum in argnums:
+            check(A, B, argnum)  # was `yield check, ...`: pytest no longer runs yield-style tests, so the checks silently never executed