From 6a30b8dccd759d7943e75577a983a8df773a15c4 Mon Sep 17 00:00:00 2001 From: Bob Cao Date: Wed, 27 Dec 2023 00:13:37 -0800 Subject: [PATCH] [lang] Warn about non-contiguous gradient tensors (#8450) Fixes: #8443 ### Brief Summary copilot:summary ### Walkthrough copilot:walkthrough --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- python/taichi/lang/kernel_impl.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 5bb3c1461671b..dfd8e238515c5 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -763,6 +763,16 @@ def call_back(): if v.requires_grad and v.grad is None: v.grad = torch.zeros_like(v) + if v.requires_grad: + if not isinstance(v.grad, torch.Tensor): + raise ValueError( + f"Expecting torch.Tensor for gradient tensor, but getting {v.grad.__class__.__name__} instead" + ) + if not v.grad.is_contiguous(): + raise ValueError( + "Non contiguous gradient tensors are not supported, please call tensor.grad.contiguous() before passing it into taichi kernel." + ) + tmp = v if (str(v.device) != "cpu") and not ( str(v.device).startswith("cuda") and taichi_arch == _ti_core.Arch.cuda