Skip to content

Commit

Permalink
Add norm and true_divide Cupy grads
Browse files Browse the repository at this point in the history
  • Loading branch information
bartvm committed Jul 31, 2017
1 parent df80463 commit 48c7096
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
2 changes: 2 additions & 0 deletions autograd/cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
from . import cupy_wrapper
from . import cupy_grads
from . import cupy_extra
from . import random
from . import linalg
from .cupy_wrapper import *
2 changes: 2 additions & 0 deletions autograd/cupy/cupy_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
# VJP (vector-Jacobian product) definitions for elementwise binary ops on
# CuPy arrays.  `unbroadcast(vs, gvs, g)` sums the incoming gradient `g`
# over any axes that were broadcast in the forward pass, so the returned
# gradient matches the shape of the corresponding input.
acp.subtract.defvjp(lambda g, ans, vs, gvs, x, y: unbroadcast(vs, gvs, -g), argnum=1)
acp.divide.defvjp( lambda g, ans, vs, gvs, x, y: unbroadcast(vs, gvs, g / y))
acp.divide.defvjp( lambda g, ans, vs, gvs, x, y: unbroadcast(vs, gvs, - g * x / y**2), argnum=1)
# true_divide has the same derivative as divide: d/dx (x/y) = 1/y,
# d/dy (x/y) = -x/y**2.
acp.true_divide.defvjp( lambda g, ans, vs, gvs, x, y: unbroadcast(vs, gvs, g / y))
acp.true_divide.defvjp( lambda g, ans, vs, gvs, x, y: unbroadcast(vs, gvs, - g * x / y**2), argnum=1)
# For max/min, `balanced_eq` routes the gradient to the argmax element(s),
# splitting it evenly on ties between x and y.
acp.maximum.defvjp( lambda g, ans, vs, gvs, x, y: unbroadcast(vs, gvs, g * balanced_eq(x, ans, y)))
acp.maximum.defvjp( lambda g, ans, vs, gvs, x, y: unbroadcast(vs, gvs, g * balanced_eq(y, ans, x)),
                    argnum=1)
Expand Down
62 changes: 62 additions & 0 deletions autograd/cupy/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import absolute_import

from functools import partial

import cupy.linalg as cpla

from .cupy_wrapper import wrap_namespace
from . import cupy_wrapper as acp

# Wrap every callable in cupy.linalg so autograd can trace it, and inject
# the wrapped versions into this module's globals (this is where the bare
# `norm` and `svd` names used below come from).
wrap_namespace(cpla.__dict__, globals())


def grad_norm(g, ans, vs, gvs, x, ord=None, axis=None):
    """VJP for cupy.linalg.norm.

    Args:
        g: incoming gradient w.r.t. the norm value(s) `ans`.
        ans: forward-pass result of norm(x, ord, axis).
        vs, gvs: value/gradient vspaces (unused here, part of defvjp API).
        x: the array the norm was taken of.
        ord: norm order as accepted by norm (None, 'fro', 'nuc', or p > 1).
        axis: None, an int (vector norm along one axis), or a 2-tuple
            (matrix norm over those two axes).

    Returns:
        Gradient with the same shape as `x`.

    Raises:
        NotImplementedError: for orders whose gradient is not implemented
            (matrix orders other than None/'fro'/'nuc'; vector orders <= 1).
    """
    def check_implemented():
        matrix_norm = (x.ndim == 2 and axis is None) or isinstance(axis, tuple)

        if matrix_norm:
            if not (ord is None or ord == 'fro' or ord == 'nuc'):
                raise NotImplementedError('Gradient of matrix norm not '
                                          'implemented for ord={}'.format(ord))
        elif not (ord is None or ord > 1):
            raise NotImplementedError('Gradient of norm not '
                                      'implemented for ord={}'.format(ord))

    # `expand` re-inserts the axes that norm reduced over, so that `g / ans`
    # broadcasts back against `x`.
    if axis is None:
        expand = lambda a: a
    elif isinstance(axis, tuple):
        row_axis, col_axis = axis
        if row_axis > col_axis:
            # After the first expand_dims the row axis index shifts by one.
            row_axis = row_axis - 1
        expand = lambda a: acp.expand_dims(acp.expand_dims(a,
                                                           row_axis), col_axis)
    else:
        expand = lambda a: acp.expand_dims(a, axis=axis)

    if ord == 'nuc':
        if axis is None:
            roll = lambda a: a
            unroll = lambda a: a
        else:
            row_axis, col_axis = axis
            if row_axis > col_axis:
                row_axis = row_axis - 1
            # Roll matrix axes to the back
            roll = lambda a: acp.rollaxis(acp.rollaxis(a, col_axis, a.ndim),
                                          row_axis, a.ndim-1)
            # Roll matrix axes to their original position
            unroll = lambda a: acp.rollaxis(acp.rollaxis(a, a.ndim-2, row_axis),
                                            a.ndim-1, col_axis)

    check_implemented()
    # BUGFIX: was `ord is 'fro'` — identity comparison against a string
    # literal, which only matched through CPython string interning.
    if ord is None or ord == 2 or ord == 'fro':
        # Frobenius / 2-norm: d||x||/dx = x / ||x||.
        return expand(g / ans) * x
    elif ord == 'nuc':
        # Nuclear norm: gradient is u @ v^T from the SVD of x.
        # `partial` requires the functools import at the top of this module.
        dot = acp.dot if x.ndim == 2 else partial(acp.einsum, '...ij,...jk->...ik')
        x_rolled = roll(x)
        u, s, vt = svd(x_rolled, full_matrices=False)
        uvt_rolled = dot(u, vt)
        # Roll the matrix axes back to their correct positions
        uvt = unroll(uvt_rolled)
        g = expand(g)
        return g * uvt
    else:
        # General p-norm, see
        # https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
        return expand(g / ans**(ord-1)) * x * acp.abs(x)**(ord-2)
norm.defvjp(grad_norm)

0 comments on commit 48c7096

Please sign in to comment.