From 8149360a72c60ff14c1197e1fb33acfe73edf1ac Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 26 Aug 2017 17:43:10 -0700 Subject: [PATCH] make dot tjp drop references as possible --- autograd/numpy/numpy_tjps.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/autograd/numpy/numpy_tjps.py b/autograd/numpy/numpy_tjps.py index 97174681..e0b6f80d 100644 --- a/autograd/numpy/numpy_tjps.py +++ b/autograd/numpy/numpy_tjps.py @@ -88,23 +88,24 @@ # ----- Trickier grads ----- def tjp_dot_arg0(ans, vs, out_vs, A, B): - if anp.ndim(B) == 0 or anp.ndim(B) == 1 or anp.ndim(A) == 0: - contract_dims = max(0, anp.ndim(B) - (anp.ndim(A) != 0)) - return lambda G: anp.tensordot(G, B, contract_dims) + A_ndim, B_ndim = vs.ndim, anp.ndim(B) + if B_ndim == 0 or B_ndim == 1 or A_ndim == 0: + contract_num = max(0, B_ndim - (A_ndim != 0)) + return lambda G: anp.tensordot(G, B, contract_num) else: - return lambda G: anp.tensordot(G, anp.swapaxes(B, -1, -2), anp.ndim(B) - 1) + return lambda G: anp.tensordot(G, anp.swapaxes(B, -1, -2), B_ndim - 1) deftjp(anp.dot, tjp_dot_arg0) def tjp_dot_arg1(ans, vs, out_vs, A, B): - needs_transpose = anp.ndim(B) > 1 and anp.ndim(A) != 0 + A_ndim, B_ndim = anp.ndim(A), vs.ndim + needs_transpose = B_ndim > 1 and A_ndim != 0 swap = (lambda x: anp.swapaxes(x, -1, -2)) if needs_transpose else (lambda x: x) - if anp.ndim(A) == 0 or anp.ndim(A) == 1 or anp.ndim(B) == 0: - contract_dims = max(0, anp.ndim(A) - (anp.ndim(B) != 0)) - return lambda G: swap(anp.tensordot(G, A, contract_dims)) + if A_ndim == 0 or A_ndim == 1 or B_ndim == 0: + contract_num = max(0, A_ndim - (B_ndim != 0)) + return lambda G: swap(anp.tensordot(G, A, contract_num)) else: return lambda G: swap(anp.tensordot( - G, A, [range(-anp.ndim(A) - anp.ndim(B) + 2, -anp.ndim(B) + 1), - range(anp.ndim(A) - 1)])) + G, A, [range(-A_ndim - B_ndim + 2, -B_ndim + 1), range(A_ndim - 1)])) deftjp(anp.dot, tjp_dot_arg1, argnum=1) def tjp_transpose(ans, in_vs, out_vs, x, axes=None):