SoFunction
Updated on 2025-04-16

PyTorch Tensor Operation Guide (cat, stack, split and chunk)

In deep learning practice, reshaping and recombining tensors is a basic skill in data processing and model construction. Whether you are fusing multimodal data (such as images and text) or splitting and regrouping batches, choosing the right tensor operation keeps the pipeline simple and efficient. PyTorch provides cat, stack, split and chunk for exactly these tasks. The sections below go through how each one works and where it is used.

1. torch.cat: Concatenating tensors along a specified dimension

Function description

torch.cat (concatenate) joins multiple shape-compatible tensors along an existing dimension and returns a single tensor with the same number of dimensions as the inputs. All dimensions other than the concatenation dimension must have exactly the same size.

Sample code

import torch

a = torch.tensor([[1, 2], [3, 4]])  # Shape (2, 2)
b = torch.tensor([[5, 6], [7, 8]])  # Shape (2, 2)

# Concatenate along dimension 0 (vertical direction)
c = torch.cat([a, b], dim=0)
print(c)
# Output:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])

# Concatenate along dimension 1 (horizontal direction)
d = torch.cat([a, b], dim=1)
print(d)
# Output:
# tensor([[1, 2, 5, 6],
#         [3, 4, 7, 8]])
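
The shape rule can be checked directly. The following is a minimal sketch with made-up tensors `x` and `y` (not from the original example): concatenation succeeds when only the concatenation dimension differs, and raises an error otherwise.

```python
import torch

x = torch.zeros(2, 3)
y = torch.ones(4, 3)

# Along dim 0 only the sizes in dim 1 have to match (3 == 3), so this works.
rows = torch.cat([x, y], dim=0)
print(rows.shape)  # torch.Size([6, 3])

# Along dim 1 the sizes in dim 0 differ (2 vs 4), so PyTorch raises an error.
try:
    torch.cat([x, y], dim=1)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```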

2. torch.stack: Stacking tensors along a new dimension

Function description

torch.stack joins the input tensors along a newly created dimension. All tensors being stacked must have exactly the same shape. The output tensor has one more dimension than the input tensors.

Sample code

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# Stack along dimension 0 to produce a two-dimensional tensor
c = torch.stack([a, b], dim=0)
print(c.shape)  # torch.Size([2, 3])
print(c)
# Output:
# tensor([[1, 2, 3],
#         [4, 5, 6]])

# Stack along dimension 1 to produce a two-dimensional tensor
d = torch.stack([a, b], dim=1)
print(d.shape)  # torch.Size([3, 2])
print(d)
# Output:
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])
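
One way to see how stack relates to cat: stacking gives the same result as first inserting a size-1 dimension into every tensor with `unsqueeze` and then concatenating along it. This is a small illustrative sketch, not part of the original example.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# stack along a new leading dimension
stacked = torch.stack([a, b], dim=0)
# the same result built from unsqueeze + cat
via_cat = torch.cat([a.unsqueeze(0), b.unsqueeze(0)], dim=0)
print(torch.equal(stacked, via_cat))  # True

# the equivalence also holds when the new dimension is placed last
stacked1 = torch.stack([a, b], dim=1)
via_cat1 = torch.cat([a.unsqueeze(1), b.unsqueeze(1)], dim=1)
print(stacked1.shape)                    # torch.Size([3, 2])
print(torch.equal(stacked1, via_cat1))   # True
```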

3. torch.split: Dividing tensors by size

Function description

torch.split divides the input tensor into multiple sub-tensors based on the specified size. It supports two parameter forms:

  • Integer list: each element is the length of the corresponding piece (the lengths must sum to the size of the split dimension)
  • Integer N: every piece has length N (if the dimension size is not divisible by N, the last piece is smaller)

Sample code

a = torch.arange(9)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

# Split by the list of sizes [2, 3, 4]
parts = torch.split(a, [2, 3, 4], dim=0)
for part in parts:
    print(part)

'''
 Output:
 tensor([0, 1])
 tensor([2, 3, 4])
 tensor([5, 6, 7, 8])
 '''

# Split into pieces of length 3 (which gives 3 pieces here)
chunks = torch.split(a, 3, dim=0)
print([c.shape for c in chunks])  # [torch.Size([3]), torch.Size([3]), torch.Size([3])]
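
The two parameter forms behave differently at the edges. The following sketch, using a length-10 tensor chosen only for illustration, shows that an integer argument is the size of each piece (so the last piece may be shorter) and that a list whose entries do not sum to the dimension length is rejected.

```python
import torch

a = torch.arange(10)

# An integer argument is the size of each piece, not the number of pieces.
parts = torch.split(a, 4, dim=0)
print([p.tolist() for p in parts])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# A list argument must sum to the length of the split dimension.
try:
    torch.split(a, [2, 3], dim=0)  # 2 + 3 != 10
    bad_sections_raised = False
except RuntimeError:
    bad_sections_raised = True
print(bad_sections_raised)  # True
```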

4. torch.chunk: Dividing a tensor evenly by count

Function description

torch.chunk divides the input tensor into N parts along the specified dimension. Each part has ceil(length / N) elements; if the length is not divisible by N, the last part is smaller, and in some cases fewer than N parts are returned.

Sample code

a = torch.arange(10)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# Divide into 3 parts; dim defaults to 0
chunks = torch.chunk(a, chunks=3, dim=0)
for i, chunk in enumerate(chunks):
    print(f"Chunk {i}: {chunk}")

'''
 Output:
 Chunk 0: tensor([0, 1, 2, 3])
 Chunk 1: tensor([4, 5, 6, 7])
 Chunk 2: tensor([8, 9])
 '''

# Chunk a two-dimensional tensor along dimension 1
b = torch.randn(2, 5)
chunks = torch.chunk(b, chunks=2, dim=1)
print(chunks[0].shape)  # torch.Size([2, 3])
print(chunks[1].shape)  # torch.Size([2, 2])
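
When the length is not divisible, it helps to look at the piece sizes explicitly. The sketch below also uses `torch.tensor_split`, which is not covered in this article and is brought in here only for comparison: it always returns the requested number of pieces and spreads the remainder over the first ones.

```python
import torch

a = torch.arange(10)

# chunk uses pieces of size ceil(10 / 3) = 4, so the last piece is shorter.
print([c.tolist() for c in torch.chunk(a, 3)])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# chunk can return fewer pieces than requested: ceil(5 / 4) = 2 per piece.
few = torch.chunk(torch.arange(5), 4)
print(len(few))  # 3

# tensor_split returns exactly 3 pieces with sizes 4, 3, 3.
balanced = torch.tensor_split(a, 3)
print([c.tolist() for c in balanced])
# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```

If a fixed number of pieces matters more than a fixed piece size, `tensor_split` may be the safer choice.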

Comprehensive example: Splitting and merging image data

The following complete example uses image data to simulate the tensor operations in a typical image preprocessing pipeline:

Scenario

Suppose we have RGB image data (each image has size 3×256×256) and need to complete the following operations:

  1. Split the image into three RGB channels
  2. Normalize each channel independently
  3. Merge the processed channels
  4. Stack multiple images into batches
  5. Split batches into training/validation sets

Code implementation

import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# 1. Load the sample image; ToTensor converts (H, W, C) -> (C, H, W)
image = Image.open('').convert('RGB')  # supply the path to a 256×256 image
image = transforms.ToTensor()(image)  # shape: torch.Size([3, 256, 256])

# 2. Use split to separate the RGB channels
r_channel, g_channel, b_channel = torch.split(image, split_size_or_sections=1, dim=0)

''' Visualize the original channels
plt.figure(figsize=(12,4))
plt.subplot(131), plt.imshow(r_channel.squeeze().numpy(), cmap='Reds'), plt.title('Red')
plt.subplot(132), plt.imshow(g_channel.squeeze().numpy(), cmap='Greens'), plt.title('Green')
plt.subplot(133), plt.imshow(b_channel.squeeze().numpy(), cmap='Blues'), plt.title('Blue')
plt.show()
'''

# 3. Normalize each channel (sample operation)
def normalize(tensor):
    return (tensor - tensor.mean()) / tensor.std()

r_norm = normalize(r_channel)
g_norm = normalize(g_channel)
b_norm = normalize(b_channel)

# 4. Use cat to merge the processed channels
normalized_img = torch.cat([r_norm, g_norm, b_norm], dim=0)
'''Observe the normalization effect
plt.imshow(normalized_img.permute(1,2,0))
plt.title('Normalized Image')
plt.show()
'''

# 5. Create a batch of simulated images (assuming there are 4 identical images)
batch_images = torch.stack([image]*4, dim=0)  # shape: (4, 3, 256, 256)

# 6. Use chunk to split the batch into a training set and a validation set
train_set, val_set = torch.chunk(batch_images, chunks=2, dim=0)
print(f"Train set size: {train_set.shape}")  # torch.Size([2, 3, 256, 256])
print(f"Val set size: {val_set.shape}")      # torch.Size([2, 3, 256, 256])
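
To try the same flow without an image file, torchvision, or matplotlib, the sketch below substitutes a random tensor for the loaded image. The random stand-in and the choice to stack the normalized image (rather than the raw one) are assumptions made for illustration only.

```python
import torch

torch.manual_seed(0)
image = torch.rand(3, 256, 256)  # stand-in for a ToTensor() image in [0, 1]

def normalize(t):
    # zero mean, unit standard deviation over the whole channel
    return (t - t.mean()) / t.std()

# split -> normalize each channel -> cat
channels = torch.split(image, 1, dim=0)  # 3 tensors of shape (1, 256, 256)
normalized_img = torch.cat([normalize(c) for c in channels], dim=0)
print(normalized_img.shape)  # torch.Size([3, 256, 256])

# stack into a batch of 4 -> chunk into two halves
batch = torch.stack([normalized_img] * 4, dim=0)
train_set, val_set = torch.chunk(batch, chunks=2, dim=0)
print(train_set.shape)  # torch.Size([2, 3, 256, 256])
print(val_set.shape)    # torch.Size([2, 3, 256, 256])
```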

Key operation analysis

Step | Function | Effect | Dimensional changes
---- | -------- | ------ | -------------------
Channel separation | torch.split | Extract individual color channels | (3,256,256) → 3 × (1,256,256)
Data merge | torch.cat | Merge the processed channel data | 3 × (1,256,256) → (3,256,256)
Batch construction | torch.stack | Copy a single image into a batch of 4 images | (3,256,256) → (4,3,256,256)
Batch division | torch.chunk | Divide the batch into equal training/validation halves | (4,3,256,256) → 2 × (2,3,256,256)

Extended application suggestions

  1. Data augmentation: apply different transformations to the channels after split (for example, adjust the contrast of the R channel only)
  2. Model input: the batch produced by stack can be fed directly into a CNN
  3. Distributed training: use chunk to distribute data to multiple GPUs for processing
  4. Feature visualization: use split to extract a single channel of an intermediate feature map for analysis
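
As a rough illustration of the first suggestion, the sketch below changes only the R channel and leaves G and B untouched. The contrast formula (scaling the deviation from the channel mean by a hypothetical `factor` and clamping to [0, 1]) is an assumption chosen for simplicity, not a method taken from this article.

```python
import torch

torch.manual_seed(0)
image = torch.rand(3, 8, 8)  # small stand-in image with values in [0, 1]

r, g, b = torch.split(image, 1, dim=0)

# Hypothetical contrast adjustment: stretch values around the channel mean.
factor = 1.5
r_adj = ((r - r.mean()) * factor + r.mean()).clamp(0.0, 1.0)

# Reassemble the image; only channel 0 has changed.
augmented = torch.cat([r_adj, g, b], dim=0)
print(augmented.shape)                        # torch.Size([3, 8, 8])
print(torch.equal(augmented[1:], image[1:]))  # True
```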

This complete image processing example shows that:

  • The split + cat combination is common in feature processing pipelines
  • The stack + chunk combination is a key tool for building batch processing systems
  • These operations provide flexible control over data layout while remaining computationally efficient

Summary and comparison

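Function | Core role | Dimensional changes | Input requirements
-------- | --------- | ------------------- | ------------------
torch.cat | Concatenate along an existing dimension | Number of dimensions unchanged | Shapes must match in every dimension except the concatenation dimension
torch.stack | Stack along a newly created dimension | +1 dimension | All tensors must have identical shapes
torch.split | Split by size | Number of dimensions unchanged | A piece size (integer) or a list of piece sizes that sums to the dimension length
torch.chunk | Divide evenly by count | Number of dimensions unchanged | A number of parts; the last part may be smaller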

Application suggestions

  • Use cat when you need to merge similar data while keeping the original number of dimensions
  • Use stack when you need a new dimension to represent batches or channels
  • Prefer split for segmenting sequence data
  • Choose chunk when dividing feature maps or tensors evenly
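
The four functions pair up as inverses, which can serve as a quick sanity check when building a pipeline. The sketch below uses an arbitrary 4×6 tensor and `torch.unbind`, which is not discussed in this article and appears here only as the counterpart of stack.

```python
import torch

x = torch.arange(24).reshape(4, 6)

# split followed by cat along the same dimension restores the tensor
split_roundtrip = torch.cat(torch.split(x, [1, 2, 3], dim=1), dim=1)
print(torch.equal(split_roundtrip, x))  # True

# chunk followed by cat does the same
chunk_roundtrip = torch.cat(torch.chunk(x, 4, dim=1), dim=1)
print(torch.equal(chunk_roundtrip, x))  # True

# unbind removes a dimension; stack along that dimension puts it back
stack_roundtrip = torch.stack(torch.unbind(x, dim=0), dim=0)
print(torch.equal(stack_roundtrip, x))  # True
```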

After mastering these tools, you will be able to manipulate tensor dimensions more flexibly to adapt to the construction needs of complex models!
