You can use the **torch.nonzero()** function. It returns a tensor containing the indices of all non-zero elements of a given input tensor. The syntax of the function is as follows:

```
torch.nonzero(input, *, out=None, as_tuple=False)
```

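With the default `as_tuple=False`, the function returns a 2-D tensor with one row per non-zero element; for a 2-D input, each row is a [row, col] pair. With `as_tuple=True`, it instead returns a tuple of 1-D index tensors, one per dimension of the input.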
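A small deterministic sketch of both return forms. The tensor `x` is a made-up example. The indices come back in row-major order, and the tuple form can be used directly to index the original tensor:

```python
import torch

# Hypothetical 2-D tensor with three non-zero entries
x = torch.tensor([[0, 5, 0],
                  [7, 0, 9]])

# Default form: shape (z, n) = (3, 2), one [row, col] pair per non-zero entry
idx = torch.nonzero(x)
print(idx.tolist())                  # [[0, 1], [1, 0], [1, 2]]

# Tuple form: n 1-D tensors (here rows and cols), each of length z
rows, cols = torch.nonzero(x, as_tuple=True)
print(rows.tolist(), cols.tolist())  # [0, 1, 1] [1, 0, 2]

# The tuple form plugs straight into advanced indexing to fetch the values
print(x[rows, cols].tolist())        # [5, 7, 9]
```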

Here is an example:

```python
>>> import numpy as np
>>> import torch
>>> a = torch.tensor(np.random.randint(0, 4, (6, 5)))
>>> a
tensor([[1, 3, 3, 2, 3],
        [1, 1, 2, 1, 2],
        [1, 2, 1, 0, 0],
        [3, 3, 3, 1, 0],
        [0, 2, 0, 0, 2],
        [1, 1, 0, 3, 0]])
>>> torch.nonzero(a)
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4],
        [1, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [1, 4],
        [2, 0],
        [2, 1],
        [2, 2],
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3],
        [4, 1],
        [4, 4],
        [5, 0],
        [5, 1],
        [5, 3]])
```
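Because the matrix above is random, here is a sketch that hard-codes the printed values so the result is reproducible. It checks that the number of index rows equals the number of non-zero entries, gathers the values with the tuple form, and compares against `torch.argwhere`. This last check assumes PyTorch 1.11 or newer, where `torch.argwhere` is available:

```python
import torch

# The 6x5 matrix printed in the session, hard-coded for reproducibility
a = torch.tensor([[1, 3, 3, 2, 3],
                  [1, 1, 2, 1, 2],
                  [1, 2, 1, 0, 0],
                  [3, 3, 3, 1, 0],
                  [0, 2, 0, 0, 2],
                  [1, 1, 0, 3, 0]])

idx = torch.nonzero(a)             # one [row, col] row per non-zero entry
print(idx.shape)                   # torch.Size([22, 2])

# The number of index rows equals the count of non-zero entries
assert idx.shape[0] == int((a != 0).sum())

# Tuple form: gather the non-zero values in the same row-major order
rows, cols = torch.nonzero(a, as_tuple=True)
values = a[rows, cols]
assert torch.equal(values, a[a != 0])

# torch.argwhere (assumed PyTorch >= 1.11) returns the same index tensor
assert torch.equal(idx, torch.argwhere(a))
```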