Detailed explanation of the Python torch.flatten function

  • 2021-11-24 02:25:45
  • OfStack

Look at the function parameters first:


torch.flatten(input, start_dim=0, end_dim=-1)

input: a tensor, i.e., the tensor to be "bulldozed" (flattened).

start_dim: the first dimension to flatten (default 0).

end_dim: the last dimension to flatten, inclusive (default -1, i.e., the last dimension).

First of all, if start_dim and end_dim are left at their default values, this function flattens input into a one-dimensional tensor of shape [n], where n is the number of elements in input.
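As a quick check of the default behaviour, the sketch below flattens a hypothetical 2 x 3 x 4 tensor built with torch.arange and reshape (an illustrative input, not from the original article). It confirms that the result is one-dimensional with n elements, and that omitting the arguments is the same as passing start_dim=0 and end_dim=-1 explicitly.

```python
import torch

# Hypothetical 3-D input with 2 * 3 * 4 = 24 elements
x = torch.arange(24).reshape(2, 3, 4)

flat = torch.flatten(x)                                # defaults: start_dim=0, end_dim=-1
explicit = torch.flatten(x, start_dim=0, end_dim=-1)   # same call, arguments spelled out

print(flat.shape)                    # torch.Size([24])
print(flat.numel() == x.numel())     # True
print(torch.equal(flat, explicit))   # True
print(flat[:5])                      # tensor([0, 1, 2, 3, 4])
```

Elements come out in row-major order, so the last dimension varies fastest.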

What if we want to set the starting dimension and the ending dimension ourselves?

Let's first look at how a tensor's shape relates to its nested square brackets:


import torch

t = torch.tensor([[[1, 2, 2, 1],
                   [3, 4, 4, 3],
                   [1, 2, 3, 4]],
                  [[5, 6, 6, 5],
                   [7, 8, 8, 7],
                   [5, 6, 7, 8]]])
print(t, t.shape)
 
 Run results: 
 
tensor([[[1, 2, 2, 1],
         [3, 4, 4, 3],
         [1, 2, 3, 4]],
 
        [[5, 6, 6, 5],
         [7, 8, 8, 7],
         [5, 6, 7, 8]]])
torch.Size([2, 3, 4])

We can see that the outermost square brackets contain two elements, so the first value of shape is 2. Similarly, the second level of square brackets contains three elements, so the second value of shape is 3, and the innermost square brackets contain four elements, so the third value of shape is 4.
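To make the bracket-counting rule concrete, this small sketch counts the elements at each nesting level of the same data held as a plain Python list and compares the counts with t.shape. The variable names outer, middle and inner are illustrative only.

```python
import torch

data = [[[1, 2, 2, 1],
         [3, 4, 4, 3],
         [1, 2, 3, 4]],
        [[5, 6, 6, 5],
         [7, 8, 8, 7],
         [5, 6, 7, 8]]]

t = torch.tensor(data)

# Count the elements at each bracket level, from the outermost inward
outer = len(data)         # 2 blocks inside the outermost brackets
middle = len(data[0])     # 3 rows inside each block
inner = len(data[0][0])   # 4 numbers inside each row

print((outer, middle, inner))   # (2, 3, 4)
print(tuple(t.shape))           # (2, 3, 4)
```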

Sample code:


x = torch.flatten(t, start_dim=1)
print(x, x.shape)
 
y = torch.flatten(t, start_dim=0, end_dim=1)
print(y, y.shape)
 
 
 Run results: 
 
tensor([[1, 2, 2, 1, 3, 4, 4, 3, 1, 2, 3, 4],
        [5, 6, 6, 5, 7, 8, 8, 7, 5, 6, 7, 8]]) 
torch.Size([2, 12])
 
tensor([[1, 2, 2, 1],
        [3, 4, 4, 3],
        [1, 2, 3, 4],
        [5, 6, 6, 5],
        [7, 8, 8, 7],
        [5, 6, 7, 8]]) 
torch.Size([6, 4])

As you can see, when start_dim = 1 and end_dim = -1, it bulldozes and merges everything from dimension 1 through the last dimension, turning shape [2, 3, 4] into [2, 12]. When start_dim = 0 and end_dim = 1, it bulldozes and merges dimensions 0 and 1, turning [2, 3, 4] into [6, 4]. The torch.nn.Flatten class and the torch.Tensor.flatten method in PyTorch are built on the same operation as the torch.flatten function above; note that torch.nn.Flatten defaults to start_dim=1, so it keeps the first (batch) dimension.
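The pattern behind both results can be stated generally: the dimensions from start_dim to end_dim (inclusive) are replaced by their product, and all other dimensions are kept as they are. The sketch below encodes that rule in a hypothetical helper, flattened_shape (not part of PyTorch), checks it against torch.flatten for every valid pair of dimensions, and confirms that torch.Tensor.flatten and torch.nn.Flatten give matching results. It assumes Python 3.8+ (for math.prod) and an input with at least one dimension.

```python
import math
import torch

def flattened_shape(shape, start_dim=0, end_dim=-1):
    """Hypothetical helper: predict the shape torch.flatten returns.

    Assumes len(shape) >= 1 (0-dim tensors are not handled here).
    """
    ndim = len(shape)
    # Normalize negative indices, e.g. -1 -> ndim - 1
    start = start_dim % ndim
    end = end_dim % ndim
    merged = math.prod(shape[start:end + 1])   # product of the merged dims
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])

t = torch.arange(24).reshape(2, 3, 4)

# Check the rule against torch.flatten for every valid (start, end) pair
for start in range(t.dim()):
    for end in range(start, t.dim()):
        actual = tuple(torch.flatten(t, start_dim=start, end_dim=end).shape)
        assert actual == flattened_shape(tuple(t.shape), start, end)

print(flattened_shape((2, 3, 4), 1, -1))   # (2, 12)
print(flattened_shape((2, 3, 4), 0, 1))    # (6, 4)

# Tensor.flatten and nn.Flatten agree with torch.flatten;
# nn.Flatten() uses start_dim=1 by default, keeping the first dimension
assert torch.equal(t.flatten(1), torch.flatten(t, start_dim=1))
assert torch.equal(torch.nn.Flatten()(t), torch.flatten(t, start_dim=1))
```

Setting start_dim equal to end_dim merges a single dimension with itself, so the shape is left unchanged.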

