torch.nn.Flatten(start_dim=1, end_dim=-1)
Purpose: flattens a contiguous range of dimensions of a tensor into a single dimension. It often appears inside nn.Sequential(), usually placed after some neural network layers, where it turns their output into a flattened tensor.
It takes two parameters, start_dim and end_dim: the first and last dimensions to flatten. Their default values are 1 and -1, where 1 means dimension 1 and -1 means the last dimension. Together, the defaults flatten everything from dimension 1 through the last dimension into one. (Note: dimensions are numbered from 0, so there is a dimension 0, and dimension 1 is actually the second one.)
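To make the index rules concrete, here is a minimal pure-Python sketch of the shape arithmetic only. The helper name flatten_shape is hypothetical and does not call PyTorch; it assumes an inclusive [start_dim, end_dim] range and counts negative values from the end, as described above.

```python
from math import prod


def flatten_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: shape after merging dims start_dim..end_dim (inclusive).

    It only models the shape arithmetic and does not call PyTorch.
    """
    ndim = len(shape)
    # Negative values count from the end, so -1 is the last dimension.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not 0 <= start <= end < ndim:
        raise ValueError("invalid start_dim/end_dim for this shape")
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


shape = (32, 1, 5, 5)
print(flatten_shape(shape))        # defaults: dims 1..-1 -> (32, 25)
print(flatten_shape(shape, 1, 3))  # 3 and -1 name the same dim here -> (32, 25)
```

For a 4-D shape, end_dim=-1 and end_dim=3 refer to the same dimension, which is why both calls give the same result.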
Likewise, if it is written as:
m = nn.Flatten(start_dim=2, end_dim=3)
then it flattens from dimension 2 through dimension 3, i.e., dimensions 2 and 3 are merged into one.
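Why does merging dimensions 2 and 3 not scramble the data? A contiguous tensor keeps its elements in one row-major buffer, and flattening only reinterprets the shape. The pure-Python sketch below checks this index correspondence on a small, assumed shape (2, 1, 3, 4); the helper names offset_4d and offset_3d are hypothetical, and real PyTorch may return a copy instead of a view when the input is not contiguous.

```python
from itertools import product

# Assumed small shape (N, C, H, W); element values equal their storage offsets.
N, C, H, W = 2, 1, 3, 4
buffer = list(range(N * C * H * W))


def offset_4d(n, c, h, w):
    """Row-major offset for shape (N, C, H, W)."""
    return ((n * C + c) * H + h) * W + w


def offset_3d(n, c, k):
    """Row-major offset for the flattened shape (N, C, H * W)."""
    return (n * C + c) * (H * W) + k


# Element (n, c, h, w) of the 4-D view is element (n, c, h * W + w) of the 3-D view.
for n, c, h, w in product(range(N), range(C), range(H), range(W)):
    assert buffer[offset_4d(n, c, h, w)] == buffer[offset_3d(n, c, h * W + w)]

print("shape", (N, C, H, W), "->", (N, C, H * W))
```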
Example from the official documentation:
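import torch
import torch.nn as nn

input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
print(output.size())  # torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
print(output.size())  # torch.Size([160, 5])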
Lines starting with # are comments.
The code first creates random data of shape (32, 1, 5, 5).
1. First apply nn.Flatten() once with the default parameters:
m = nn.Flatten()
This flattens from dimension 1 to the last dimension. Because numbering starts at 0, dimension 1 is the second entry of the shape, which is 1 in this sample.
So the result after flattening is [32, 1 × 5 × 5] ➡ [32, 25]
2. Then apply nn.Flatten() again with explicit parameters:
m = nn.Flatten(0, 2)
This flattens from dimension 0 through dimension 2, i.e., the first three dimensions are merged.
The result is therefore [32 × 1 × 5, 5] ➡ [160, 5]
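The two walkthrough steps can be double-checked with plain arithmetic. This sketch hard-codes the sample shape (32, 1, 5, 5) and only mimics the shapes nn.Flatten produces (the comments name the calls from the walkthrough; no PyTorch is used). It confirms that both results hold the same 800 elements as the input.

```python
from math import prod

shape = (32, 1, 5, 5)

# nn.Flatten() with defaults (start_dim=1, end_dim=-1): keep dim 0, merge dims 1..3.
default_result = (shape[0], prod(shape[1:]))
# nn.Flatten(0, 2): merge dims 0..2, keep dim 3.
custom_result = (prod(shape[:3]), shape[3])

# Flattening only regroups dimensions, so the element count never changes.
assert prod(shape) == prod(default_result) == prod(custom_result) == 800
print(default_result, custom_result)  # (32, 25) (160, 5)
```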
Example 1
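Convolution output-size formula (per spatial dimension, dilation 1): H_out = ⌊(H_in + 2 × padding − kernel_size) / stride⌋ + 1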
import torch
import torch.nn as nn

input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
    nn.Conv2d(1, 32, 5, 1, 1),  # After the convolution: torch.Size([32, 32, 3, 3])
    nn.Flatten(),
)
output = m(input)
print(output.size())
# >> torch.Size([32, 288])
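The shape in the comment, [32, 32, 3, 3], follows from the convolution output-size formula. The sketch below uses a hypothetical helper conv2d_out_size that applies the standard per-dimension formula (written with a dilation term, which is 1 here) and then the default flattening, as pure arithmetic without PyTorch.

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of one spatial dimension of a 2-D convolution."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# nn.Conv2d(1, 32, 5, 1, 1): 32 output channels, kernel 5, stride 1, padding 1.
batch, out_channels = 32, 32
h = conv2d_out_size(5, kernel_size=5, stride=1, padding=1)
w = conv2d_out_size(5, kernel_size=5, stride=1, padding=1)
conv_shape = (batch, out_channels, h, w)
flat_shape = (batch, out_channels * h * w)  # what nn.Flatten() with defaults yields
print(conv_shape, flat_shape)  # (32, 32, 3, 3) (32, 288)
```

With dilation 1, the dilation term reduces to the simpler form H_in + 2 × padding − kernel_size in the numerator: (5 + 2 − 5) / 1 + 1 = 3.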
Example 2
import torch
import torch.nn as nn

input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
    nn.Conv2d(1, 32, 5, 1, 1),  # After the convolution: torch.Size([32, 32, 3, 3])
    nn.Flatten(start_dim=0),
)
output = m(input)
print(output.size())
# >> torch.Size([9216])
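Examples 1 and 2 differ only in start_dim. This small sketch, again pure arithmetic rather than PyTorch, shows what each choice does to the convolution output shape (32, 32, 3, 3).

```python
from math import prod

conv_shape = (32, 32, 3, 3)  # output of nn.Conv2d(1, 32, 5, 1, 1) in both examples

# start_dim=1 (the nn.Flatten default) keeps the batch dimension separate.
keep_batch = (conv_shape[0], prod(conv_shape[1:]))
# start_dim=0 merges the batch dimension too, leaving a single 1-D tensor.
merge_all = (prod(conv_shape),)
print(keep_batch, merge_all)  # (32, 288) (9216,)
```

Keeping start_dim=1 is usually what a model needs, because a following layer such as nn.Linear expects (batch, features). The functional torch.flatten defaults to start_dim=0, so it behaves like Example 2 unless start_dim is passed explicitly.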