You can use either the torch.narrow() function or slicing to select a range of rows or columns from a tensor. The narrow() function has the following signature:
torch.narrow(input, dim, start, length) → Tensor
Where:
input (Tensor) – the input tensor
dim (int) – the dimension along which to narrow (for a 2-D tensor: 0 for rows, 1 for columns)
start (int) – the index of the first row/column to select
length (int) – how many rows/columns to select
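In other words, narrow() keeps the half-open window start:start+length along dim and leaves every other dimension whole. Here is a minimal sketch of that rule. narrow_via_slicing is a hypothetical helper written for illustration, not part of PyTorch, and torch.arange stands in for torch.randn so the values are deterministic. The loop checks the helper against torch.narrow for every valid combination of arguments.

```python
import torch


def narrow_via_slicing(t: torch.Tensor, dim: int, start: int, length: int) -> torch.Tensor:
    """Hypothetical helper: the slice that torch.narrow(t, dim, start, length) selects."""
    # Keep every dimension whole (:) except `dim`, where we keep the
    # half-open window [start, start + length).
    index = [slice(None)] * t.dim()
    index[dim] = slice(start, start + length)
    return t[tuple(index)]


# Deterministic stand-in for torch.randn(5, 4): rows are [0..3], [4..7], ...
a = torch.arange(20).reshape(5, 4)

# Compare the helper with torch.narrow for every valid (dim, start, length).
for dim in range(a.dim()):
    size = a.size(dim)
    for start in range(size):
        for length in range(1, size - start + 1):
            expected = torch.narrow(a, dim, start, length)
            assert torch.equal(narrow_via_slicing(a, dim, start, length), expected)

print(narrow_via_slicing(a, 0, 2, 2))  # rows 2 and 3
print(narrow_via_slicing(a, 1, 1, 3))  # columns 1, 2 and 3
```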
Using the narrow() function
>>> import torch
>>> a=torch.randn(5,4)
>>> a
tensor([[-0.9016, -0.6995, 1.3679, 0.1771],
[ 1.2528, -0.0611, 0.5726, 0.3936],
[ 2.0479, -0.7027, 1.1459, 0.8682],
[-1.4382, -1.5006, -0.1019, -0.2421],
[-0.7981, 1.2505, 0.4924, -0.5110]])
>>> torch.narrow(a,0,2,2) # select 2 rows starting from row_idx=2
tensor([[ 2.0479, -0.7027, 1.1459, 0.8682],
[-1.4382, -1.5006, -0.1019, -0.2421]])
>>> torch.narrow(a,1,1,3) # select 3 columns starting from col_idx=1
tensor([[-0.6995, 1.3679, 0.1771],
[-0.0611, 0.5726, 0.3936],
[-0.7027, 1.1459, 0.8682],
[-1.5006, -0.1019, -0.2421],
[ 1.2505, 0.4924, -0.5110]])
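Note that narrow() does not copy data: according to the PyTorch documentation, the returned tensor shares the same underlying storage as input. The sketch below (again assuming a deterministic torch.arange tensor in place of the random one) shows that writing through the narrowed tensor changes the original, and that calling .clone() gives an independent copy when that is not what you want.

```python
import torch

a = torch.arange(20).reshape(5, 4)

# narrow() returns a view: no data is copied.
rows = torch.narrow(a, 0, 2, 2)  # rows 2 and 3
rows.zero_()                     # in-place write through the view...
print(a)                         # ...rows 2 and 3 of `a` are now zeros

# Use .clone() when you want an independent copy instead.
b = torch.arange(20).reshape(5, 4)
cols = torch.narrow(b, 1, 1, 3).clone()
cols.zero_()                     # only the copy changes; `b` is untouched
print(b)
```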
Using slicing
>>> a[2:4,] # rows 2 and 3 (end index 4 is excluded), same as torch.narrow(a,0,2,2)
tensor([[ 2.0479, -0.7027, 1.1459, 0.8682],
[-1.4382, -1.5006, -0.1019, -0.2421]])
>>> a[:,1:4] # columns 1 to 3, same as torch.narrow(a,1,1,3)
tensor([[-0.6995, 1.3679, 0.1771],
[-0.0611, 0.5726, 0.3936],
[-0.7027, 1.1459, 0.8682],
[-1.5006, -0.1019, -0.2421],
[ 1.2505, 0.4924, -0.5110]])
>>>
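To confirm that the two approaches agree, this sketch compares each slice against the matching narrow() call on a deterministic torch.arange tensor (an assumption, used instead of torch.randn so the values are reproducible). It also shows that the trailing comma in a[2:4,] is optional, and that basic slicing likewise returns a view of a.

```python
import torch

a = torch.arange(20).reshape(5, 4)

# The trailing comma in a[2:4,] is optional; all three spellings select rows 2 and 3.
assert torch.equal(a[2:4,], a[2:4])
assert torch.equal(a[2:4,], a[2:4, :])

# Slice end indices are exclusive, so start:start+length matches narrow(start, length).
assert torch.equal(a[2:4], torch.narrow(a, 0, 2, 2))     # 2 rows from row 2
assert torch.equal(a[:, 1:4], torch.narrow(a, 1, 1, 3))  # 3 columns from column 1

# Basic slicing also returns a view, so writes show up in `a`.
cols = a[:, 1:4]
cols[0, 0] = 100
print(a[0])  # the element at row 0, column 1 is now 100
```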