In deep learning practice, reshaping and recombining tensors is a basic skill for data processing and model construction. Whether fusing multimodal data (such as images and text) or splitting and regrouping batches, using tensor operations well can significantly streamline the computation. PyTorch's cat, stack, split, and chunk are the tools for exactly these problems. The following sections analyze their principles and applications one by one.
1. torch.cat: Concatenate tensors along a specified dimension
Function description
torch.cat (concatenate) joins multiple shape-compatible tensors along an existing dimension into one larger tensor; the number of dimensions does not change. All dimensions other than the concatenation dimension must match exactly.
Sample code
```python
import torch

a = torch.tensor([[1, 2], [3, 4]])  # shape (2, 2)
b = torch.tensor([[5, 6], [7, 8]])  # shape (2, 2)

# Concatenate along dimension 0 (vertically)
c = torch.cat([a, b], dim=0)
print(c)
# Output:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])

# Concatenate along dimension 1 (horizontally)
d = torch.cat([a, b], dim=1)
print(d)
# Output:
# tensor([[1, 2, 5, 6],
#         [3, 4, 7, 8]])
```
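To make the shape-compatibility rule concrete, here is a small illustrative check (my own sketch, not part of the original examples): concatenating a (2, 2) tensor with a (2, 3) tensor succeeds along dim=1, where only dimension 0 has to match, but fails along dim=0:

```python
import torch

a = torch.ones(2, 2)
b = torch.ones(2, 3)

# Along dim=1 the non-concatenated dimension (0) matches, so this works
print(torch.cat([a, b], dim=1).shape)  # torch.Size([2, 5])

# Along dim=0 the non-concatenated dimension (1) differs (2 vs 3), so this fails
try:
    torch.cat([a, b], dim=0)
except RuntimeError as err:
    print(f"cat failed: {err}")
```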
2. torch.stack: Stack tensors along a new dimension
Function description
torch.stack stacks the input tensors along a newly created dimension; all tensors participating in the stack must have exactly the same shape. The output tensor has one more dimension than the inputs.
Sample code
```python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# Stack along dimension 0 to produce a 2-D tensor
c = torch.stack([a, b], dim=0)
print(c.shape)  # torch.Size([2, 3])
print(c)
# Output:
# tensor([[1, 2, 3],
#         [4, 5, 6]])

# Stack along dimension 1 to produce a 2-D tensor
d = torch.stack([a, b], dim=1)
print(d.shape)  # torch.Size([3, 2])
print(d)
# Output:
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])
```
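A useful way to relate `stack` to `cat` (a minimal sketch of my own, not from the original article): stacking is equivalent to unsqueezing each tensor along the new dimension and then concatenating along it.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

stacked = torch.stack([a, b], dim=0)
# Equivalent: insert the new dimension first, then concatenate along it
equivalent = torch.cat([a.unsqueeze(0), b.unsqueeze(0)], dim=0)
print(torch.equal(stacked, equivalent))  # True
```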
3. torch.split: Split a tensor by size
Function description
Split the input tensor into multiple sub-tensors along a given dimension, based on the specified size. Two parameter forms are supported:
- A list of integers: each element gives the length of the corresponding chunk along the split dimension (the lengths must sum to the size of that dimension)
- A single integer n: every chunk has length n along the split dimension, and the last chunk is smaller if the dimension size is not divisible by n
Sample code
```python
a = torch.arange(9)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

# Split by the list of sizes [2, 3, 4]
parts = torch.split(a, [2, 3, 4], dim=0)
for part in parts:
    print(part)
'''
Output:
tensor([0, 1])
tensor([2, 3, 4])
tensor([5, 6, 7, 8])
'''

# Split into chunks of size 3
chunks = torch.split(a, 3, dim=0)
print([c.shape for c in chunks])  # [torch.Size([3]), torch.Size([3]), torch.Size([3])]
```
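One detail worth demonstrating (an added sketch, not in the original): with the integer form, no error is raised when the dimension size is not divisible; the last chunk simply comes out smaller.

```python
import torch

a = torch.arange(10)

# Chunk size 4 over length 10: the last chunk holds the remaining 2 elements
parts = torch.split(a, 4, dim=0)
print([p.shape for p in parts])  # [torch.Size([4]), torch.Size([4]), torch.Size([2])]
```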
4. torch.chunk: Divide a tensor evenly by count
Function description
Divide the input tensor into (at most) N chunks along the specified dimension. If the size is not divisible by N, each chunk gets size ceil(size / N) and the last chunk takes whatever remains, so it can be smaller; in some cases fewer than N chunks are returned.
Sample code
```python
a = torch.arange(10)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# Divide into 3 chunks; dim=0 is the default
chunks = torch.chunk(a, chunks=3, dim=0)
for i, chunk in enumerate(chunks):
    print(f"Chunk {i}: {chunk}")
'''
Output (each chunk has size ceil(10 / 3) = 4, the last takes the remainder):
Chunk 0: tensor([0, 1, 2, 3])
Chunk 1: tensor([4, 5, 6, 7])
Chunk 2: tensor([8, 9])
'''

# Chunk a 2-D tensor along dimension 1
b = torch.randn(2, 5)
chunks = torch.chunk(b, chunks=2, dim=1)
print(chunks[0].shape)  # torch.Size([2, 3])
print(chunks[1].shape)  # torch.Size([2, 2])
```
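A corner case worth knowing (an added sketch of my own): because every chunk gets size ceil(size / N), `chunk` can return fewer chunks than requested.

```python
import torch

a = torch.arange(13)

# 6 chunks requested, but each chunk has size ceil(13 / 6) = 3,
# so 5 chunks already cover all 13 elements
parts = torch.chunk(a, chunks=6, dim=0)
print(len(parts))               # 5
print([len(p) for p in parts])  # [3, 3, 3, 3, 1]
```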
Comprehensive example: Splitting and merging image data
The following is a complete example that uses image data to walk through the tensor operations in a typical preprocessing pipeline:
Scene setting
Suppose we have RGB image data (each image is 3×256×256), and the following operations need to be completed:
- Split the image into its three RGB channels
- Normalize each channel independently
- Merge the processed channels
- Stack multiple images into a batch
- Split the batch into training/validation sets
Code implementation
```python
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# 1. Load the sample image (H, W, C) -> convert to (C, H, W)
# 'your_image.jpg' is a placeholder; the original filename was elided
image = Image.open('your_image.jpg').convert('RGB')
image = transforms.ToTensor()(image)  # shape: torch.Size([3, 256, 256]) for a 256x256 image

# 2. Use split to separate the RGB channels
r_channel, g_channel, b_channel = torch.split(image, split_size_or_sections=1, dim=0)

'''
# Visualize the original channels
plt.figure(figsize=(12, 4))
plt.subplot(131); plt.imshow(r_channel.squeeze().numpy(), cmap='Reds'); plt.title('Red')
plt.subplot(132); plt.imshow(g_channel.squeeze().numpy(), cmap='Greens'); plt.title('Green')
plt.subplot(133); plt.imshow(b_channel.squeeze().numpy(), cmap='Blues'); plt.title('Blue')
plt.show()
'''

# 3. Normalize each channel independently (example operation)
def normalize(tensor):
    return (tensor - tensor.mean()) / tensor.std()

r_norm = normalize(r_channel)
g_norm = normalize(g_channel)
b_norm = normalize(b_channel)

# 4. Use cat to merge the processed channels
normalized_img = torch.cat([r_norm, g_norm, b_norm], dim=0)

'''
# Observe the normalization effect
plt.imshow(normalized_img.permute(1, 2, 0))
plt.title('Normalized Image')
plt.show()
'''

# 5. Create a simulated batch (4 copies of the same image)
batch_images = torch.stack([image] * 4, dim=0)  # shape: (4, 3, 256, 256)

# 6. Use chunk to split the batch into training/validation sets
train_set, val_set = torch.chunk(batch_images, chunks=2, dim=0)
print(f"Train set size: {train_set.shape}")  # torch.Size([2, 3, 256, 256])
print(f"Val set size: {val_set.shape}")      # torch.Size([2, 3, 256, 256])
```
Key operation analysis
Step | Function | Effect | Dimensional change |
---|---|---|---|
Channel separation | `torch.split` | Extract individual color channels | (3,256,256) → 3 × (1,256,256) |
Data merge | `torch.cat` | Merge the processed channel data | 3 × (1,256,256) → (3,256,256) |
Batch construction | `torch.stack` | Copy a single image into a batch of 4 images | (3,256,256) → (4,3,256,256) |
Batch division | `torch.chunk` | Divide the batch into training/validation sets | (4,3,256,256) → 2 × (2,3,256,256) |
Extended application suggestions
- Data augmentation: apply different transforms to the channels after split (e.g., adjust the contrast of the R channel only)
- Model input: the batch produced by stack can be fed directly into a CNN
- Distributed training: use chunk to distribute data across multiple GPUs (see the sketch after this list)
- Feature visualization: extract single channels from intermediate feature maps with split for analysis
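As a rough illustration of the distributed-training point, here is a minimal sketch (my own addition, not from the original article; `scatter_batch` is a hypothetical helper, and real data parallelism would normally go through `torch.nn.parallel.DistributedDataParallel`). It shards a batch across devices with `chunk`:

```python
import torch

def scatter_batch(batch, devices):
    # Hypothetical helper: split the batch along dim 0 and move each shard to a device
    shards = torch.chunk(batch, chunks=len(devices), dim=0)
    return [shard.to(device) for shard, device in zip(shards, devices)]

batch = torch.randn(8, 3, 256, 256)
# Fall back to two CPU "devices" so the demo runs without GPUs
devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())] or ["cpu", "cpu"]
shards = scatter_batch(batch, devices)
print([tuple(s.shape) for s in shards])  # e.g. [(4, 3, 256, 256), (4, 3, 256, 256)] with two devices
```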
Through this complete image-processing example, you can clearly see:
- `split` + `cat` combinations are often used in feature-processing pipelines
- `stack` + `chunk` combinations are key tools for building batch-processing systems
- These operations provide flexible control over data while maintaining computational efficiency
Summary and comparison
Function | Core role | Dimensional change | Input requirements |
---|---|---|---|
`cat` | Concatenate along an existing dimension | unchanged | Shapes must match in all dimensions except the concatenation dimension |
`stack` | Stack along a new dimension | +1 dimension | All tensors must have identical shapes |
`split` | Split by size | unchanged | Takes a chunk size or a list of sizes summing to the dimension size |
`chunk` | Divide evenly by count | unchanged | Takes a chunk count; divisibility is not required |
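One practical caveat the table does not capture (an added note, not from the original article): `split` and `chunk` return views that share storage with the input tensor, while `cat` and `stack` allocate new memory. In-place changes to a view therefore modify the source:

```python
import torch

x = torch.arange(6)
first, second = torch.split(x, 3, dim=0)

# split/chunk return views: the in-place add shows up in the original tensor
first += 100
print(x)  # tensor([100, 101, 102,   3,   4,   5])

# Call .clone() on a piece if you need an independent copy
independent = second.clone()
```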
Application suggestions:
- Use `cat` when you need to merge similar data while keeping the number of dimensions;
- Use `stack` when you need a new dimension to represent batches or channels;
- Prefer `split` when segmenting sequence data into specified sizes;
- Choose `chunk` when evenly dividing feature maps or tensors.
After mastering these tools, you will be able to manipulate tensor dimensions more flexibly to meet the demands of building complex models!
This is the article about PyTorch tensor operation guide (cat, stack, split and chunk). For more related PyTorch tensor operation content, please search for my previous articles or continue browsing the related articles below. I hope everyone will support me in the future!