
Implement grid_sample for sampling from a spatial tensor at arbitrary locations, similar to PyTorch #2674

Open
timstr opened this issue Jan 8, 2025 · 7 comments
Labels: feature (The feature request)

timstr commented Jan 8, 2025

Feature description

In brief, I propose implementing PyTorch's grid_sample method in Burn. It takes a tensor of points in 1D, 2D, or 3D Euclidean space and looks up and interpolates values from a tensor with 1, 2, or 3 spatial dimensions (plus channel and batch dimensions). This differs from the existing interpolate method because the sample locations need not lie on a regular grid. It differs from the existing gather method because it uses floating-point locations instead of integer indices, it interpolates between neighbouring values, and its autodiff computes gradients with respect to the sampling locations themselves as well as the sampled tensor.
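To make the contrast with gather concrete, here is a minimal std-only Rust sketch of 1D linear sampling with zero padding. It is not Burn's or PyTorch's API: `at` and `sample_1d` are hypothetical helpers, and for simplicity the sample locations are fractional indices rather than PyTorch's normalized [-1, 1] coordinates. Besides the sampled value, it returns the derivative of that value with respect to the sample location, which is the extra gradient the proposed autodiff support would provide.

```rust
/// Hypothetical helper: reads `data[i]`, treating every index outside the
/// slice as zero (zero padding).
fn at(data: &[f32], i: i64) -> f32 {
    if i >= 0 && (i as usize) < data.len() {
        data[i as usize]
    } else {
        0.0
    }
}

/// Hypothetical helper: linearly interpolates `data` at the fractional index
/// `x` and also returns d(value)/dx, the gradient with respect to the sample
/// location itself (not with respect to `data`).
fn sample_1d(data: &[f32], x: f32) -> (f32, f32) {
    let x0 = x.floor();
    let t = x - x0; // fractional part in [0, 1)
    let i0 = x0 as i64;
    let v0 = at(data, i0);
    let v1 = at(data, i0 + 1);
    let value = v0 * (1.0 - t) + v1 * t;
    // Derivative of the lerp w.r.t. x; at exact integers this is the
    // one-sided (right) derivative.
    let grad_x = v1 - v0;
    (value, grad_x)
}

fn main() {
    let data = [0.0_f32, 10.0, 20.0, 40.0];
    // Arbitrary floating-point locations, including one past the last element.
    let locations = [0.5_f32, 1.25, 2.5, 3.5];
    for &x in &locations {
        let (v, g) = sample_1d(&data, x);
        println!("x = {x:4}: value = {v:6.2}, d/dx = {g:6.2}");
    }
}
```

The interpolant has a kink at every integer location, so any autodiff implementation has to pick a convention for the location gradient there; this sketch uses the right-hand slope.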

Feature motivation

grid_sample is a powerful low-level primitive, closely related to the existing interpolation and gather methods but with more freedom to sample values at arbitrary spatial locations, and with gradient information describing changes in the sample locations in addition to the grid values. Its possible use cases go beyond neural networks, and it's well suited to more general-purpose (differentiable) tensor computation.

The tensor being sampled could be an image, with the sample locations coming from a non-linear projection such as Cartesian-to-polar. It could be a terrain heightmap, with the sample locations and their gradients used to model rainfall and erosion. Or it could be a volumetric scene, with the sample locations being particles in a simulation.

I've personally used PyTorch's grid_sample during my thesis research project to, among other things, implement a basic GPU-parallelized renderer that performs sphere tracing on a signed distance field represented as a volumetric tensor (link to research; the 3D figures were generated this way). I would love to be able to do similar work in Rust using Burn.

laggui added the feature label on Jan 9, 2025
timstr (author) commented Jan 15, 2025

I'm taking a stab at implementing this over at https://github.com/timstr/burn/tree/feat/grid_sample. So far I seem to have a working implementation of a grid_sample_1d op. I'll hopefully have 2D and 3D grid sampling, as well as autodiff, working soon, and will put up a pull request once I get there.

I'm quite new to Burn, and it's a large project, so any early feedback would be very welcome!
timstr@6ebfe29

timstr (author) commented Jan 15, 2025

One thing I would like to support is different padding modes for when sample indices fall outside the data, such as zero-padding, repeat-padding, and mirror-padding. In principle, this choice would be well captured by a simple enum and would be equally valid for other convolution and padding operations in Burn, which AFAICT don't yet support a choice of padding method like this. At least, I would expect to see it in conv2d or ConvOptions, but I don't see anything like zero/repeat/mirror/reflect mentioned.
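A minimal std-only Rust sketch of how such an enum might resolve an out-of-range index. `PaddingMode` and `read_padded` are hypothetical names, not existing Burn APIs. PyTorch's grid_sample calls the corresponding modes "zeros", "border", and "reflection", but its reflection works on continuous coordinates and depends on align_corners; the `Mirror` variant here is an assumption that reflects integer indices about the edge sample without repeating it (a "symmetric" variant would repeat the edge).

```rust
/// Hypothetical padding modes for out-of-range sample indices.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PaddingMode {
    /// Out-of-range reads return zero.
    Zero,
    /// Out-of-range reads repeat the nearest edge value.
    Repeat,
    /// Out-of-range reads mirror back into the data (edge not repeated).
    Mirror,
}

/// Hypothetical helper: reads `data[i]`, resolving out-of-range indices
/// according to `mode`.
fn read_padded(data: &[f32], i: i64, mode: PaddingMode) -> f32 {
    let n = data.len() as i64;
    if n == 0 {
        return 0.0;
    }
    match mode {
        PaddingMode::Zero => {
            if (0..n).contains(&i) {
                data[i as usize]
            } else {
                0.0
            }
        }
        PaddingMode::Repeat => data[i.clamp(0, n - 1) as usize],
        PaddingMode::Mirror => {
            if n == 1 {
                return data[0];
            }
            // Mirroring without repeating the edge is periodic with period 2 * (n - 1).
            let period = 2 * (n - 1);
            let m = i.rem_euclid(period);
            let j = if m < n { m } else { period - m };
            data[j as usize]
        }
    }
}

fn main() {
    let data = [1.0_f32, 2.0, 3.0, 4.0];
    for mode in [PaddingMode::Zero, PaddingMode::Repeat, PaddingMode::Mirror] {
        // Read two elements past each end to show the padding behaviour.
        let row: Vec<f32> = (-2..6).map(|i| read_padded(&data, i, mode)).collect();
        println!("{mode:?}: {row:?}");
    }
}
```

Keeping the mode as a plain `Copy` enum means the same index-resolution function could, in principle, be shared by sampling, padding, and convolution code paths.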

dfsfdfse commented

I have the same requirement. I tried using the slice method, but it causes my grid to be cloned repeatedly, which makes the loop very slow.

laggui (member) commented Jan 16, 2025

> One thing I would like to support is different padding modes for when indices extend beyond the data, such as zero-padding, repeat-padding, and mirror-padding. In principle this choice would be well captured by a simple enum and would be equally valid for other convolution and padding operations in Burn which AFAICT don't yet support a choice of padding method like this. At least, I would expect to see it in conv2d or ConvOptions but I don't see anything like zero/repeat/mirror/reflect mentioned.

For convolutions, anything other than zero padding is just an explicit padding operation that is applied before the convolution.

For padding, we only have tensor.pad(padding, value) right now, which pads a tensor with an explicit value on the left, right, top, and bottom. This could be extended to other modes.
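To illustrate the point that non-zero padding is just an explicit padding step followed by an unpadded convolution, here is a std-only Rust sketch. `pad_replicate` and `conv1d_valid` are hypothetical helpers, not Burn APIs. It repeat-pads a 1D signal by one element on each side and then runs a "valid" 3-tap convolution, so the output has the same length as the input.

```rust
/// Hypothetical helper: pads `x` on both sides by `p`, repeating the edge
/// values ("repeat"/"replicate" padding). Assumes `x` is non-empty.
fn pad_replicate(x: &[f32], p: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(x.len() + 2 * p);
    out.extend(std::iter::repeat(x[0]).take(p));
    out.extend_from_slice(x);
    out.extend(std::iter::repeat(x[x.len() - 1]).take(p));
    out
}

/// Hypothetical helper: "valid" 1D convolution (cross-correlation, as in
/// deep-learning frameworks) with no implicit padding.
fn conv1d_valid(x: &[f32], kernel: &[f32]) -> Vec<f32> {
    x.windows(kernel.len())
        .map(|w| w.iter().zip(kernel).map(|(a, k)| a * k).sum::<f32>())
        .collect()
}

fn main() {
    let signal = [1.0_f32, 2.0, 4.0, 8.0];
    let kernel = [1.0_f32, 1.0, 1.0]; // 3-tap box filter
    // Same-size output: pad by (kernel_len - 1) / 2 on each side first,
    // then convolve without any padding.
    let padded = pad_replicate(&signal, 1);
    let out = conv1d_valid(&padded, &kernel);
    println!("{padded:?} -> {out:?}");
}
```

Swapping `pad_replicate` for a mirror or constant-value variant changes the boundary behaviour without touching the convolution itself.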

laggui (member) commented Jan 16, 2025

> I'm taking a stab at implementing this over at https://github.com/timstr/burn/tree/feat/grid_sample, so far I seem to have a working implementation of a grid_sample_1d op. I'll hopefully have 2D and 3D grid sampling as well as autodiff working soon and will put up a pull request once I get there.
>
> I'm quite new to Burn and it's quite a large project, so any early feedback would be very welcome! timstr@6ebfe29

Looks like you're on the right track!

Have you looked at the contributor book? It has details about adding a new tensor op, which might be helpful.

dfsfdfse commented

I have implemented a custom backend that supports grid sampling on 4-dimensional tensors with zero padding and bilinear interpolation. Would you mind taking a look to see whether any other improvements could be made? I tested it on my 4070s, and it takes 5 microseconds. Burn and CubeCL are truly impressive.
(Attached image: grid-sample)
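For reference when checking a kernel like this, here is a minimal CPU sketch in std-only Rust of bilinear grid sampling with zero padding. It follows PyTorch's align_corners=False convention for mapping normalized [-1, 1] coordinates to pixel coordinates, with grid points given as (x, y) pairs. It is not the implementation described above: it handles a single channel of a single image, so the batch and channel loops of a 4D (N, C, H, W) input are omitted, and `Image`, `unnormalize`, and `grid_sample_bilinear` are hypothetical names.

```rust
/// Hypothetical single-channel H x W image stored row-major.
struct Image {
    h: usize,
    w: usize,
    data: Vec<f32>,
}

impl Image {
    /// Pixel at (y, x); anything outside the image reads as 0 (zero padding).
    fn get(&self, y: i64, x: i64) -> f32 {
        if y < 0 || x < 0 || y >= self.h as i64 || x >= self.w as i64 {
            0.0
        } else {
            self.data[y as usize * self.w + x as usize]
        }
    }
}

/// Maps a normalized coordinate in [-1, 1] to a pixel coordinate using the
/// align_corners = false convention (-1 and 1 are the outer pixel edges).
fn unnormalize(coord: f32, size: usize) -> f32 {
    ((coord + 1.0) * size as f32 - 1.0) / 2.0
}

/// Bilinearly samples `img` at normalized (gx, gy) with zero padding.
fn grid_sample_bilinear(img: &Image, gx: f32, gy: f32) -> f32 {
    let x = unnormalize(gx, img.w);
    let y = unnormalize(gy, img.h);
    let (x0, y0) = (x.floor(), y.floor());
    let (tx, ty) = (x - x0, y - y0); // fractional offsets in [0, 1)
    let (x0, y0) = (x0 as i64, y0 as i64);
    // Weighted sum of the four neighbouring pixels.
    img.get(y0, x0) * (1.0 - tx) * (1.0 - ty)
        + img.get(y0, x0 + 1) * tx * (1.0 - ty)
        + img.get(y0 + 1, x0) * (1.0 - tx) * ty
        + img.get(y0 + 1, x0 + 1) * tx * ty
}

fn main() {
    // 2 x 2 image: [[1, 2], [3, 4]]
    let img = Image { h: 2, w: 2, data: vec![1.0, 2.0, 3.0, 4.0] };
    // Sample locations as (x, y) pairs; (-1, -1) is the outer top-left corner.
    let grid = [(-0.5_f32, -0.5_f32), (0.0, 0.0), (0.5, 0.5), (-1.0, -1.0)];
    for (gx, gy) in grid {
        println!("({gx:5}, {gy:5}) -> {:.3}", grid_sample_bilinear(&img, gx, gy));
    }
}
```

At the outer corner (-1, -1) three of the four neighbours fall outside the image, so zero padding pulls the result down to a quarter of the corner pixel's value; comparing such edge cases is a quick way to check a GPU kernel against a reference.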

laggui (member) commented Jan 22, 2025

> I have implemented a custom backend that supports 4-dimensional padding with zero bilinear interpolation for grid sampling. Would you mind taking a look to see if there are any other improvements that could be made? I tested it on my 4070s, and it takes 5 microseconds. Burn and CubeCL are truly impressive.

Pretty cool comparison! There must be some improvements to be made; I'd expect similar or better performance from the CubeCL kernel. Tagging @nathanielsimard for visibility, but I'm not sure we actually have time to review external stuff 😅
