SoFunction
Updated on 2024-10-29

Using reshape, view, and flatten in PyTorch

When defining neural network structures in PyTorch, you will often see calls such as .view() and flatten(). This article explains what they do and demonstrates their use.

torch.reshape() usage

reshape() can be called as torch.reshape() or as the tensor method torch.Tensor.reshape(). It changes the shape of a tensor without changing the number of elements in it.

torch.reshape() takes two arguments: the tensor to be reshaped and the new shape.

torch.reshape(input, shape) → Tensor

  • input (Tensor) - the tensor to be reshaped
  • shape (tuple of ints) - the new shape

Case 1.

Input:

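import torch
a = torch.tensor([[0, 1], [2, 3]])
# -1 lets PyTorch infer the size: flatten a into one dimension
x = torch.reshape(a, (-1,))
print(x)
# b is defined here but not used in the rest of this example
b = torch.arange(4.)
# Reshape a back to 2 rows and 2 columns
Y = torch.reshape(a, (2, 2))
print(Y)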

Results:

tensor([0, 1, 2, 3])
tensor([[0, 1],
[2, 3]])
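
As a quick check of the two call forms described above, here is a minimal sketch (the variable names are illustrative and not from the original post). It shows that the function form and the method form give the same result, and that a target shape with a different number of elements is rejected.

```python
import torch

a = torch.tensor([[0, 1], [2, 3]])

# Function form and method form produce the same tensor.
flat_fn = torch.reshape(a, (-1,))
flat_method = a.reshape(-1)
same = torch.equal(flat_fn, flat_method)
print(same)

# The element count must stay the same: 4 elements cannot fill a 3x2 shape.
try:
    torch.reshape(a, (3, 2))
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
print(mismatch_rejected)
```

Both printed values should be True: the two call forms agree, and the mismatched shape raises an error instead of silently dropping or padding elements.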

view() usage

The principle of view() is simple: conceptually, the data in the original tensor is laid out in a single row, and the elements are then taken from that row in order to fill a tensor of the shape given to view().
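
The "single row" description can be checked directly. The sketch below uses an illustrative 12-element tensor (not from the original post) and confirms that element (i, j) of a 3x4 view is element i * 4 + j of the flat row.

```python
import torch

a = torch.arange(0, 12)   # 12 elements in a single row: 0, 1, ..., 11
v = a.view(3, 4)          # read them back 4 at a time into 3 rows

# Element (i, j) of the 3x4 view is element i * 4 + j of the flat row.
i, j = 2, 1
matches_flat_order = v[i, j].item() == a[i * 4 + j].item()
print(v)
print(matches_flat_order)
```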

view() accepts as many arguments as the number of dimensions you want in the result. Two arguments, giving a 2-D tensor, are typical, and this form is common in neural networks, usually just before a fully connected layer.

In view(h, w), h is the number of rows and w is the number of columns. If you know the number of columns but not the number of rows, pass -1 for h; if you know the number of rows but not the number of columns, pass -1 for w.

I. General usage (manual adjustment)

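view() works like reshape(): it changes the shape of the Tensor without changing its elements. Unlike reshape(), view() never copies data, so the result always shares memory with the original tensor.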
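
One practical difference between view() and reshape() is worth seeing in code. The sketch below (illustrative names, not from the original post) shows that a view shares its data with the original tensor, and that view() refuses a non-contiguous tensor where reshape() falls back to making a copy.

```python
import torch

a = torch.arange(0, 6)
v = a.view(2, 3)

# A view shares storage with the original, so writing through it changes a.
v[0, 0] = 100
shares_data = a[0].item() == 100

# A transposed tensor is not contiguous: view() raises, reshape() still works.
t = v.t()
try:
    t.view(6)
    view_failed = False
except RuntimeError:
    view_failed = True
r = t.reshape(6)  # succeeds; in this case it returns a copy

print(shares_data, view_failed)
print(r)
```

When a tensor may be non-contiguous (for example after a transpose), reshape() or .contiguous().view(...) is the safer choice.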

Case 2.

Input:

import torch
a1 = torch.arange(0, 16)
print(a1)

Output:

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

Input:

a2 = a1.view(8, 2)
a3 = a1.view(2, 8)
a4 = a1.view(4, 4)
print(a2)
print(a3)
print(a4)

Output:

tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])

II. Special usage: parameter -1 (automatic resizing)

When one argument of view() is set to -1, the size of that dimension is inferred automatically so that the total number of elements stays the same.

Input:

import torch
a1 = torch.arange(0, 16)
print(a1)

Output:

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

Input:

a2 = a1.view(-1, 16)
a3 = a1.view(-1, 8)
a4 = a1.view(-1, 4)
a5 = a1.view(-1, 2)
a6 = a1.view(4*4, -1)
a7 = a1.view(1*4, -1)
a8 = a1.view(2*4, -1)
print(a2)
print(a3)
print(a4)
print(a5)
print(a6)
print(a7)
print(a8)

Output:

tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
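
The -1 placeholder has two limits that the examples above do not show. The following sketch (the names a9, two_inferred_ok and uneven_ok are illustrative) checks that -1 also works with more than two dimensions, that only one dimension can be inferred, and that the known dimensions must divide the element count evenly.

```python
import torch

a1 = torch.arange(0, 16)

# -1 is inferred as 16 / (2 * 4) = 2.
a9 = a1.view(2, -1, 4)
print(a9.shape)

# Only one dimension may be -1.
try:
    a1.view(-1, -1)
    two_inferred_ok = True
except RuntimeError:
    two_inferred_ok = False

# The known dimensions must divide the element count: 16 is not divisible by 3.
try:
    a1.view(-1, 3)
    uneven_ok = True
except RuntimeError:
    uneven_ok = False

print(two_inferred_ok, uneven_ok)
```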

torch.nn.Flatten(start_dim=1, end_dim=-1)

start_dim and end_dim are the first and last dimensions to flatten, with default values of 1 and -1, where -1 means the last dimension. Together, the defaults mean that everything from dimension 1 through the last dimension is flattened into one dimension. (Note: dimensions are numbered from 0, so dimension 1 is actually the second dimension of the tensor.)

In neural networks the input is usually a batch of data, and dimension 0 is the batch dimension (the number of samples). The usual goal is to flatten each sample into one dimension, not to flatten the whole batch into one dimension, so torch.nn.Flatten() starts flattening from dimension 1 by default.

Using torch.nn.Flatten() with default parameters

Official example given:

import torch
from torch import nn

input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
print(output.size())
# torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
print(output.size())
# torch.Size([160, 5])

Lines beginning with # are comments.

The code as a whole starts from random data of shape (32, 1, 5, 5).

1. First, use nn.Flatten() with the default parameters:

m = nn.Flatten()

This flattens from dimension 1 through the last dimension. Dimensions are numbered from 0, so dimension 1 is the one in the second position of the shape, which has size 1 in this example.

Thus the shape after flattening is [32, 1*5*5] → [32, 25]

2. Next, use nn.Flatten() with explicit parameters:

m = nn.Flatten(0, 2)

This flattens from dimension 0 through dimension 2, that is, the first three dimensions.

Thus the shape is [32*1*5, 5] → [160, 5]
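
The shape arithmetic in the two steps can be reproduced by multiplying the sizes of the merged dimensions. This is a small sketch with illustrative variable names, using math.prod from the standard library (Python 3.8 or later).

```python
import math

import torch
from torch import nn

x = torch.randn(32, 1, 5, 5)
shape = list(x.shape)

# Default nn.Flatten(): merge dimensions 1 through the last, keep dimension 0.
expected_default = [shape[0], math.prod(shape[1:])]
# nn.Flatten(0, 2): merge dimensions 0 through 2, keep dimension 3.
expected_0_2 = [math.prod(shape[0:3]), shape[3]]

got_default = list(nn.Flatten()(x).shape)
got_0_2 = list(nn.Flatten(0, 2)(x).shape)

print(expected_default, got_default)  # [32, 25] [32, 25]
print(expected_0_2, got_0_2)          # [160, 5] [160, 5]
```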

The torch.flatten() function is often used when writing classification networks. After the last convolutional layer there is typically an adaptive pooling layer, whose output is a tensor of shape B*C*H*W.

torch.flatten() is then used to flatten this tensor into shape B*x (where x = C*H*W) before it is fed into the fully connected (FC) layer.
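
To make this pattern concrete, here is a minimal sketch of a toy classifier. The class name TinyClassifier and all layer sizes are hypothetical and chosen only for illustration, and the sketch assumes the flattening step is done with torch.flatten(x, 1).

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical toy network; the layer sizes are illustrative only."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((2, 2))
        self.fc = nn.Linear(8 * 2 * 2, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv(x))  # B x 8 x H x W
        x = self.pool(x)              # B x 8 x 2 x 2  (B, C, H, W)
        x = torch.flatten(x, 1)       # B x 32         (x = C * H * W)
        return self.fc(x)             # B x num_classes


model = TinyClassifier()
logits = model(torch.randn(4, 3, 16, 16))
print(logits.shape)  # torch.Size([4, 10])
```

Passing 1 as start_dim keeps the batch dimension intact, so each of the 4 samples becomes its own 32-element row before the FC layer.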


Syntax

 torch.flatten(input, start_dim=0, end_dim=-1)

  • input: the tensor to be flattened.
  • start_dim: the first dimension to flatten.
  • end_dim: the last dimension to flatten.

This works like torch.nn.Flatten(), except that it is a function instead of a class, and it starts at dimension 0 by default.
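
The difference in defaults is easy to verify. The sketch below (illustrative variable names) flattens the same tensor with the function's defaults, with the module's defaults, and with the function called with start_dim=1.

```python
import torch
from torch import nn

x = torch.randn(2, 2, 3, 3)

fn_default = torch.flatten(x)      # start_dim=0: every dimension is merged
module_default = nn.Flatten()(x)   # start_dim=1: the batch dimension is kept
fn_from_1 = torch.flatten(x, 1)    # matches the module default

print(fn_default.shape)      # torch.Size([36])
print(module_default.shape)  # torch.Size([2, 18])
print(torch.equal(fn_from_1, module_default))  # True
```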

Example 1:

import torch
# Simulate the output of the last pooling or adaptive pooling layer: batch_size*c*h*w
data_pool = torch.randn(2, 2, 3, 3)
print(data_pool)
# Flatten from dimension 1 onward, keeping the batch dimension
y = torch.flatten(data_pool, 1)
print(y)

Output results:

[Image: printed values of data_pool and y]

The result is a B*x tensor, here of shape [2, 18].

Summary

The above is based on my personal experience. I hope it serves as a useful reference, and I appreciate your continued support.