diff --git a/autograd/numpy/numpy_tjps.py b/autograd/numpy/numpy_tjps.py index 97174681..bdf965b9 100644 --- a/autograd/numpy/numpy_tjps.py +++ b/autograd/numpy/numpy_tjps.py @@ -88,23 +88,25 @@ # ----- Trickier grads ----- def tjp_dot_arg0(ans, vs, out_vs, A, B): - if anp.ndim(B) == 0 or anp.ndim(B) == 1 or anp.ndim(A) == 0: - contract_dims = max(0, anp.ndim(B) - (anp.ndim(A) != 0)) - return lambda G: anp.tensordot(G, B, contract_dims) + A_ndim, B_ndim = vs.ndim, anp.ndim(B) + if B_ndim == 0 or B_ndim == 1 or A_ndim == 0: + contract_num = max(0, B_ndim - (A_ndim != 0)) + return lambda G: anp.tensordot(G, B, contract_num) else: - return lambda G: anp.tensordot(G, anp.swapaxes(B, -1, -2), anp.ndim(B) - 1) + return lambda G: anp.tensordot(G, anp.swapaxes(B, -1, -2), B_ndim - 1) deftjp(anp.dot, tjp_dot_arg0) def tjp_dot_arg1(ans, vs, out_vs, A, B): - needs_transpose = anp.ndim(B) > 1 and anp.ndim(A) != 0 + A_ndim, B_ndim = anp.ndim(A), vs.ndim + needs_transpose = B_ndim > 1 and A_ndim != 0 swap = (lambda x: anp.swapaxes(x, -1, -2)) if needs_transpose else (lambda x: x) - if anp.ndim(A) == 0 or anp.ndim(A) == 1 or anp.ndim(B) == 0: - contract_dims = max(0, anp.ndim(A) - (anp.ndim(B) != 0)) + if A_ndim == 0 or A_ndim == 1 or B_ndim == 0: + contract_dims = max(0, A_ndim - (B_ndim != 0)) return lambda G: swap(anp.tensordot(G, A, contract_dims)) else: return lambda G: swap(anp.tensordot( - G, A, [range(-anp.ndim(A) - anp.ndim(B) + 2, -anp.ndim(B) + 1), - range(anp.ndim(A) - 1)])) + G, A, [range(-A_ndim - B_ndim + 2, -B_ndim + 1), + range(A_ndim - 1)])) deftjp(anp.dot, tjp_dot_arg1, argnum=1) def tjp_transpose(ans, in_vs, out_vs, x, axes=None):