diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 5bb3c1461671b..dfd8e238515c5 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -763,6 +763,16 @@ def call_back(): if v.requires_grad and v.grad is None: v.grad = torch.zeros_like(v) + if v.requires_grad: + if not isinstance(v.grad, torch.Tensor): + raise ValueError( + f"Expecting torch.Tensor for gradient tensor, but getting {v.grad.__class__.__name__} instead" + ) + if not v.grad.is_contiguous(): + raise ValueError( + "Non contiguous gradient tensors are not supported, please call tensor.grad.contiguous() before passing it into taichi kernel." + ) + tmp = v if (str(v.device) != "cpu") and not ( str(v.device).startswith("cuda") and taichi_arch == _ti_core.Arch.cuda