How can I count the number of elements in a PyTorch tensor that are greater than a given value "k"?

1 Answer

Best answer

First, compare the tensor with the given value to get a boolean mask of the elements that are greater than it. Then use that mask to select those elements and apply the torch.numel() function to the resulting tensor to get the count.

Here is an example:

>>> import torch
>>> a=torch.randn(6,4)
>>> a
tensor([[-0.0457, -0.4924, -0.7026,  0.0567],
        [-0.5104, -0.1395, -0.3003,  0.8491],
        [ 2.2846,  0.5619, -0.1806,  0.9625],
        [ 0.7884,  1.1767,  2.0025, -0.0589],
        [-0.1579,  0.8199, -0.5279,  0.2966],
        [ 0.0946, -0.7405,  0.4907,  1.3673]])
>>> a>1
tensor([[False, False, False, False],
        [False, False, False, False],
        [ True, False, False, False],
        [False,  True,  True, False],
        [False, False, False, False],
        [False, False, False,  True]])
>>> torch.numel(a[a>1])
4
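As an alternative, the boolean mask can be summed directly instead of first extracting the matching elements, because each True counts as 1. The sketch below wraps this in a helper; the name count_greater is hypothetical, and the tensor values are a small fixed example rather than the random tensor above. torch.count_nonzero is also shown and assumes a reasonably recent PyTorch release.

```python
import torch

def count_greater(t: torch.Tensor, k: float) -> int:
    """Count the elements of `t` that are strictly greater than `k`.

    `count_greater` is a hypothetical helper name used only for illustration.
    """
    mask = t > k  # boolean tensor with the same shape as t
    # Summing a boolean tensor counts the True entries; .item() turns the
    # 0-d result into a plain Python number.
    return int(mask.sum().item())

a = torch.tensor([[-0.0457, 2.2846],
                  [ 1.1767, 0.5619]])
print(count_greater(a, 1))                # 2
print(torch.numel(a[a > 1]))              # 2, the masked-selection approach
print(torch.count_nonzero(a > 1).item())  # 2
```

Summing the mask skips the intermediate tensor of selected values that a[a > k] allocates; for small tensors either approach works.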

