You can use torch.take(input, index), passing the indices of the elements you want as index. take() treats the input tensor as if it were flattened to 1D (in row-major order), and the result has the same shape as the index tensor.
Here are some examples (torch.randn generates random values, so your numbers will differ):
When the indices are a 1D tensor:
>>> import torch
>>> a=torch.randn(6,4)
>>> a
tensor([[-0.3410, -2.3171, 0.2685, -1.4083],
[-0.1782, 0.4501, 0.4013, -0.4777],
[-0.8800, -0.8078, -1.0272, 0.0961],
[-1.2799, -0.5404, -1.3871, -1.5463],
[-0.3515, -0.0466, -1.5026, 0.6122],
[ 0.7668, -1.1009, -0.5753, -0.0123]])
>>> i=torch.tensor([1, 5, 6, 8])
>>> torch.take(a,i)
tensor([-2.3171, 0.4501, 0.4013, -0.8800])
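To see why those four values come out, here is a small sketch. It seeds the generator only so the run is repeatable. It checks that torch.take(a, i) matches indexing the row-major flattened tensor, and that each flat index k maps back to row k // ncols and column k % ncols. The names flat, rows, cols and by_coords are just illustrative.

```python
import torch

# Seed only so the run is repeatable; the actual values don't matter here.
torch.manual_seed(0)
a = torch.randn(6, 4)
i = torch.tensor([1, 5, 6, 8])

taken = torch.take(a, i)

# take() indexes the row-major flattened view of a, so this is equivalent.
flat = a.flatten()[i]

# Flat index k corresponds to row k // ncols and column k % ncols.
ncols = a.size(1)
rows = torch.div(i, ncols, rounding_mode="floor")  # tensor([0, 1, 1, 2])
cols = i % ncols                                   # tensor([1, 1, 2, 0])
by_coords = a[rows, cols]

print(torch.equal(taken, flat), torch.equal(taken, by_coords))  # True True
```

So index 8 in the example above is row 2, column 0, which is why -0.8800 appears last.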
When the indices are a 2D tensor (reusing a from the example above):
>>> i=torch.tensor([[1,2],[3,4]])
>>> torch.take(a,i)
tensor([[-2.3171, 0.2685],
[-1.4083, -0.1782]])
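If you start from (row, col) coordinates rather than flat indices, one way to build the index tensor for take() is k = row * ncols + col. The sketch below picks the same four positions as the 2D example. The grid layout and the names rows, cols and flat_idx are illustrative assumptions. It also shows that ordinary advanced indexing a[rows, cols] gives the same 2x2 result.

```python
import torch

# Seed only so the run is repeatable.
torch.manual_seed(0)
a = torch.randn(6, 4)

# (row, col) coordinates to pick, arranged as a 2x2 grid.
rows = torch.tensor([[0, 0], [0, 1]])
cols = torch.tensor([[1, 2], [3, 0]])

# Convert to flat row-major indices: k = row * ncols + col.
flat_idx = rows * a.size(1) + cols  # tensor([[1, 2], [3, 4]])

out = torch.take(a, flat_idx)
print(out.shape)  # torch.Size([2, 2]), same shape as flat_idx

# Advanced indexing with the coordinates selects the same elements.
print(torch.equal(out, a[rows, cols]))  # True
```

If you already have coordinates, a[rows, cols] is usually the more direct choice. torch.take is handy when your indices are already flat.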