How to Implement Meshgrid in PyTorch

In this post, I’ll show how to implement meshgrid in PyTorch.

The following figure shows what a meshgrid looks like in NumPy:

Image credit: https://www.python-course.eu/matplotlib_contour_plot.php

If we have two tensors x and y:

>>> import numpy as np
>>> import torch
>>> x = torch.from_numpy(np.array([1, 2, 3, 4]))
>>> y = torch.from_numpy(np.array([11, 22, 33, 44]))

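We’d like a tensor z of shape [4, 4, 2] in which z[i, j] is the concatenation of x[i] and y[j]. The trick is to first build two 4x4 tensors: xx, whose row i is filled with x[i], and yy, whose rows are all copies of y.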
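As a reference point, the definition of z can be written directly as a naive double loop. This is a sketch for checking results only. It builds x and y with torch.tensor instead of from_numpy so it does not need NumPy, and z_ref is a hypothetical name.

```python
import torch

x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([11, 22, 33, 44])

# Allocate the [len(x), len(y), 2] result with the inputs' dtype.
z_ref = torch.empty(len(x), len(y), 2, dtype=x.dtype)
for i in range(len(x)):
    for j in range(len(y)):
        # z[i, j] holds the pair (x[i], y[j]).
        z_ref[i, j, 0] = x[i]
        z_ref[i, j, 1] = y[j]

print(z_ref[1, 2].tolist())  # [2, 33]
```

The Python-level loops make this slow for large inputs; the repeat-based construction does the same work with tensor operations.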

>>> xx = x.view(-1, 1).repeat(1, 4)
>>> xx

 1  1  1  1
 2  2  2  2
 3  3  3  3
 4  4  4  4
[torch.LongTensor of size 4x4]

>>> yy = y.repeat(4, 1)
>>> yy

 11  22  33  44
 11  22  33  44
 11  22  33  44
 11  22  33  44
[torch.LongTensor of size 4x4]
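Both xx and yy are real copies, because repeat allocates new memory. A possible alternative, not the one used in this post, is expand, which returns a broadcast view of the original data without copying. A minimal sketch, assuming 1-D inputs built with torch.tensor:

```python
import torch

x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([11, 22, 33, 44])

# Column vector of x broadcast across len(y) columns: row i is all x[i].
xx = x.view(-1, 1).expand(len(x), len(y))
# Row vector of y broadcast down len(x) rows: every row is y.
yy = y.view(1, -1).expand(len(x), len(y))

# Same values as the repeat() versions.
assert torch.equal(xx, x.view(-1, 1).repeat(1, 4))
assert torch.equal(yy, y.repeat(4, 1))

# No copy: the broadcast dimension has stride 0.
print(xx.stride(), yy.stride())  # (1, 0) (0, 1)
```

torch.cat still accepts these views and returns an ordinary, freshly allocated tensor. Avoid writing into xx or yy in place, since many of their elements share a single memory location.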

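Now we add a new trailing dimension to xx and yy with unsqueeze and concatenate them along dimension 2. We use the out-of-place unsqueeze here: the in-place unsqueeze_ would also modify xx and yy themselves, so re-running the line in the same session would keep adding dimensions.

>>> meshed = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2)
>>> meshed.size()
torch.Size([4, 4, 2])

Let’s verify the results:

>>> meshed[1, 2]

  2
 33
[torch.LongTensor of size 2]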

>>> x[1]
2
>>> y[2]
33
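The steps above can be wrapped into a reusable function. A minimal sketch, where pair_grid is a hypothetical name rather than a PyTorch API: it replaces the unsqueeze-then-cat pair with torch.stack, which does the same thing in one call, and lets x and y have different lengths. It then cross-checks the result against the built-in torch.meshgrid, assuming PyTorch 1.10 or later, where the indexing argument is available.

```python
import torch

def pair_grid(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return a [len(x), len(y), 2] tensor whose [i, j] entry is (x[i], y[j]).

    Hypothetical helper; assumes x and y are 1-D tensors of the same dtype.
    """
    n, m = x.numel(), y.numel()
    xx = x.view(-1, 1).repeat(1, m)  # row i is filled with x[i]
    yy = y.view(1, -1).repeat(n, 1)  # every row is a copy of y
    # Equivalent to torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).
    return torch.stack([xx, yy], dim=2)

x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20])
grid = pair_grid(x, y)
print(grid.shape)           # torch.Size([3, 2, 2])
print(grid[2, 1].tolist())  # [3, 20]

# Cross-check with the built-in (needs PyTorch >= 1.10 for `indexing`).
# indexing="ij" means gx[i, j] == x[i] and gy[i, j] == y[j].
gx, gy = torch.meshgrid(x, y, indexing="ij")
assert torch.equal(grid, torch.stack([gx, gy], dim=2))
```

NumPy's np.meshgrid defaults to indexing='xy', which swaps the first two axes of its outputs; pass indexing='ij' there as well to get the same layout as above.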
