From 9725fa6fe128bd75741075ea1509ab6eeb85316b Mon Sep 17 00:00:00 2001
From: Kevin Swersky
Date: Tue, 14 Mar 2017 22:53:37 -0700
Subject: [PATCH] Fixed error in grad_chooser (e.g., for max) when dtype is not
 numpy float64.

---
 autograd/numpy/numpy_grads.py |  6 +++++-
 autograd/util.py              |  2 ++
 tests/test_numpy.py           | 10 ++++++++++
 3 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/autograd/numpy/numpy_grads.py b/autograd/numpy/numpy_grads.py
index 0fc231b9b..b735bec1a 100644
--- a/autograd/numpy/numpy_grads.py
+++ b/autograd/numpy/numpy_grads.py
@@ -249,8 +249,12 @@ def grad_chooser(g, ans, vs, gvs, x, axis=None, keepdims=None):
     """Builds gradient of functions that choose a single item, such as min or max."""
     g_repeated, _ = repeat_to_match_shape(g, vs, axis, keepdims)
     argmax_locations = x == repeat_to_match_shape(ans, vs, axis, keepdims)[0]
+    if onp.isscalar(x.value):
+        dt = onp.array(x.value).dtype
+    else:
+        dt = x.dtype
     return g_repeated * argmax_locations \
-        / onp.sum(argmax_locations, axis=axis, keepdims=True)
+        / onp.sum(argmax_locations, axis=axis, keepdims=True).astype(dt)
 anp.max.defvjp(grad_chooser)
 anp.min.defvjp(grad_chooser)
 
diff --git a/autograd/util.py b/autograd/util.py
index a56facd6b..eaa05b30d 100644
--- a/autograd/util.py
+++ b/autograd/util.py
@@ -22,6 +22,8 @@ def unary_nd(f, x, eps=EPS):
     vs = vspace(x)
     nd_grad = np.zeros(vs.size)
     x_flat = vs.flatten(x)
+    if x_flat.dtype != np.float64:
+        nd_grad = nd_grad.astype(x_flat.dtype)
     for d in range(vs.size):
         dx = np.zeros(vs.size)
         dx[d] = eps/2
diff --git a/tests/test_numpy.py b/tests/test_numpy.py
index 86c1a739f..c75e582ef 100644
--- a/tests/test_numpy.py
+++ b/tests/test_numpy.py
@@ -81,6 +81,16 @@ def fun(x): return to_scalar(np.max(x))
     check_grads(fun, mat)
     check_grads(d_fun, mat)
 
+def test_max_dtype():
+    """Tests that a dtype other than float64 does not throw an error
+    with the gradient of max.
+    """
+    def fun(x): return to_scalar(np.max(x, 1))
+    d_fun = lambda x : to_scalar(grad(fun)(x))
+    mat = npr.randn(10, 11).astype(np.float32)
+    check_grads(fun, mat)
+    check_grads(d_fun, mat)
+
 def test_max_axis():
     def fun(x): return to_scalar(np.max(x, axis=1))
     d_fun = lambda x : to_scalar(grad(fun)(x))
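
For context, a minimal standalone sketch of the dtype promotion the first hunk
fixes (not part of the patch; plain NumPy with no autograd wrappers, and the
name `counts` is illustrative). Summing the boolean argmax mask yields an
integer array (int64 on most platforms), and NumPy promotes float32 / int64 to
float64, so the VJP came back in the wrong dtype for non-float64 inputs. The
`.astype(dt)` cast keeps the division in the input's dtype:

    import numpy as onp

    # float32 stand-ins for the values grad_chooser sees.
    x = onp.random.randn(10, 11).astype(onp.float32)   # primal input
    ans = onp.max(x, axis=1, keepdims=True)            # forward answer, float32
    g = onp.ones_like(ans)                             # incoming gradient, float32

    argmax_locations = x == ans                        # boolean mask of the maxima
    counts = onp.sum(argmax_locations, axis=1, keepdims=True)  # integer dtype

    print((g * argmax_locations / counts).dtype)                  # float64: promoted
    print((g * argmax_locations / counts.astype(x.dtype)).dtype)  # float32: patched behavior

The other two hunks follow from the same issue: unary_nd casts its
numerical-gradient buffer to the input's dtype, and the new test exercises the
gradient of max on a float32 matrix so check_grads compares like with like.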