You can use the torch.numel() function (numel is short for "number of elements") to find the number of elements in a given tensor.
Here is an example:
>>> import torch
>>> a = torch.tensor([1, 2, 3, 4])
>>> a
tensor([1, 2, 3, 4])
>>> torch.numel(a)
4
>>> a = torch.randn(5, 6)
>>> torch.numel(a)
30
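As a complementary sketch, the same count is also available as a method on the tensor itself, and it should equal the product of the tensor's dimensions. The variable names below are only illustrative, and math.prod assumes Python 3.8 or newer.

```python
import math
import torch

# A 5x6 tensor of random values; only its shape matters here.
a = torch.randn(5, 6)

# Function form, as used in the answer above.
n_func = torch.numel(a)

# Method form on the tensor object.
n_method = a.numel()

# Cross-check: the element count is the product of the sizes in a.shape.
n_shape = math.prod(a.shape)

print(n_func, n_method, n_shape)  # 30 30 30
```

Both forms return a plain Python int, so the result can be used directly in arithmetic or comparisons.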