1. Introduction
In supervised learning, we can use the fit() function to train a model, iteratively optimizing its parameters so that it fits the training data better and better.
But when we want to control every little detail, we can write our own training loop entirely from scratch. That gives us a custom training algorithm, but what if we still want to benefit from the convenient features of fit(), such as callbacks, built-in distribution support, or step fusing?
A core principle of Keras is progressive disclosure of complexity: we can always move gradually toward a lower-level workflow. If the high-level features don't fully meet our requirements, customizing fit() lets us take control of the small details while retaining a commensurate amount of high-level convenience.
When you need to customize the behavior of fit(), you should override the training step function of the Model class. This is the function that fit() calls for every batch of data. You can then call fit() as usual, and it will run your own learning algorithm.
2. Preparation
2.1 Setup
Before running the examples, set up the environment as follows:
```python
import os

# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras
from keras import layers
import numpy as np
```
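If you want to double-check that the torch backend is actually active, here is an optional sanity check (not part of the original setup):

```python
# Optional: confirm Keras is running on the torch backend.
print(keras.backend.backend())  # expected output: "torch"
```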
2.2 Sample code
Let's start with a simple example to get a feel for customizing what happens inside fit() with the PyTorch backend.
First, we create a new class that subclasses keras.Model.
Then we only need to override its train_step(self, data) method.
This method should return a dictionary mapping metric names (including the loss) to their current values.
The input argument data is what gets passed to fit() as training data (both cases are illustrated in the short sketch after the list below):
- If you pass NumPy arrays, by calling fit(x, y, ...), then data will be the tuple (x, y).
- If you pass a torch.utils.data.DataLoader or a tf.data.Dataset, by calling fit(dataset, ...), then data will be whatever the dataset yields at each batch.
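As a quick, hedged illustration of these two cases (the tiny model and random data here are made up for demonstration and are not part of the main example):

```python
# A throwaway model, only to show how `data` reaches train_step() in both cases.
demo_model = keras.Sequential([keras.Input(shape=(32,)), layers.Dense(1)])
demo_model.compile(optimizer="adam", loss="mse")

x = np.random.random((1000, 32)).astype("float32")
y = np.random.random((1000, 1)).astype("float32")

# Case 1: NumPy arrays -- each batch arrives in train_step() as a tuple (x, y).
demo_model.fit(x, y, epochs=1)

# Case 2: a torch DataLoader -- each batch arrives as whatever the loader yields,
# here an (x_batch, y_batch) pair of torch tensors.
dataset = torch.utils.data.TensorDataset(torch.from_numpy(x), torch.from_numpy(y))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
demo_model.fit(dataloader, epochs=1)
```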
In the body of the train_step() method, we implement a regular training update. Importantly, we compute the loss via self.compute_loss(), which wraps the loss function(s) that were passed to compile().
Similarly, we call metric.update_state(y, y_pred) on the metrics from self.metrics, to update the state of the metrics that were passed in compile(), and we query self.metrics at the end to retrieve their current values.
```python
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(y=y, y_pred=y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
```
Let's try it out; running the following code produces the output shown below.
```python
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
```
```
Epoch 1/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 551us/step - mae: 0.6533 - loss: 0.6036
Epoch 2/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 522us/step - mae: 0.4013 - loss: 0.2522
Epoch 3/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 516us/step - mae: 0.3813 - loss: 0.2256

<keras.src.callbacks.history.History at 0x299b7baf0>
```
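Because this custom train_step() plugs into the normal fit() machinery, the usual conveniences such as callbacks keep working. A small hedged sketch (the EarlyStopping settings here are just an illustration, not from the original example):

```python
# fit()'s built-in features still apply on top of the custom train_step(),
# e.g. callbacks. Here we stop training once the loss stops improving.
early_stop = keras.callbacks.EarlyStopping(monitor="loss", patience=2)
model.fit(x, y, epochs=20, callbacks=[early_stop])
```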
3. Going lower-level
Naturally, you could also skip passing a loss function in compile() and instead handle everything manually in train_step. The same goes for metrics.
Here is a lower-level example that only uses compile() to configure the optimizer:
- We start by creating Metric instances to track our loss and an MAE score (in the __init__() method).
- We implement a custom train_step() that updates the state of these metrics (by calling update_state() on them), then queries them (via result()) to return their current average values, to be displayed by the progress bar and passed to any callback.
Note that we would need to call reset_states() on our metrics between each epoch! Otherwise calling result() would return the average since the start of training, whereas we usually work with per-epoch averages. Thankfully, the framework can do that for us: just list any metric you want to reset in the model's metrics property. The model will call reset_states() on any object listed there at the beginning of each fit() epoch or at the beginning of a call to evaluate().
```python
class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def train_step(self, data):
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.loss_fn(y, y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
```
```
Epoch 1/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 461us/step - loss: 0.2470 - mae: 0.3953
Epoch 2/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 456us/step - loss: 0.2386 - mae: 0.3910
Epoch 3/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 456us/step - loss: 0.2359 - mae: 0.3901
Epoch 4/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 480us/step - loss: 0.2013 - mae: 0.3572
Epoch 5/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 463us/step - loss: 0.1903 - mae: 0.3480

<keras.src.callbacks.history.History at 0x299c5eec0>
```
3.1 Supporting sample_weight and class_weight
The basic example at the start of this article does not handle sample weighting. If you want to support the sample_weight and class_weight arguments of fit(), you would simply do the following:
- Unpack sample_weight from the data argument.
- Pass it to compute_loss and update_state (of course, you could also just apply it manually if you don't rely on compile() for losses and metrics).
```python
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
        )

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
```
```
Epoch 1/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 499us/step - mae: 1.4332 - loss: 1.0769
Epoch 2/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 520us/step - mae: 0.9250 - loss: 0.5614
Epoch 3/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 502us/step - mae: 0.6069 - loss: 0.2653

<keras.src.callbacks.history.History at 0x299c82bf0>
```
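The same train_step() also covers the class_weight argument, because fit() converts class weights into per-sample weights before the training step runs. A hedged sketch with made-up binary labels (not part of the original example):

```python
# class_weight is converted by fit() into sample_weight before train_step()
# is called, so the implementation above supports it unchanged.
y_class = np.random.randint(0, 2, size=(1000, 1))  # made-up binary labels
model.fit(x, y_class, class_weight={0: 1.0, 1: 3.0}, epochs=3)
```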
3.2 Providing your own evaluation step
What if you want to do the same customization for calls to model.evaluate()? Then you would override test_step() in exactly the same way. Here's what it looks like:
```python
class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Updates the metrics tracking the loss
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
```
```
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 325us/step - mae: 0.4427 - loss: 0.2993

[0.2726495862007141, 0.42286917567253113]
```
4. A complete end-to-end example
To tie together everything we've learned so far, let's walk through an end-to-end GAN (generative adversarial network) example that relies on a custom fit() with the PyTorch backend.
In this example, we will consider:
- A generator network that produces 28x28x1 images.
- A discriminator network that classifies 28x28x1 images into two classes ("fake" and "real").
- One optimizer for each network.
- A loss function to train the discriminator.
First, we need to define the architecture of the generator and the discriminator. For brevity we won't discuss every layer in detail here (the full definitions appear in the code further down), but conceptually the generator takes noise as input and outputs an image, while the discriminator takes an image as input and outputs a probability value indicating whether the image is real (from the training set) or fake (produced by the generator).
Here is the general process of GAN training:
1. Initialize the generator and discriminator networks:
   - Define the model architecture of the generator and of the discriminator.
   - Compile the discriminator, specifying a loss function (such as binary cross-entropy) and an optimizer (such as Adam).
2. Train the discriminator:
   - For a batch of real images, compute the discriminator's loss (using the "real" label, 1).
   - Generate a batch of fake images with the generator and compute the discriminator's loss on them (using the "fake" label, 0).
   - Add the two losses and perform a gradient descent update on the discriminator.
3. Train the generator:
   - Generate a batch of fake images.
   - Run the discriminator on these fake images to obtain probability values.
   - Using labels that say "these images are real" (we want the discriminator to classify the generated images as real), compute the generator's loss; this is usually done by passing the discriminator's predictions to a loss function such as binary cross-entropy or mean squared error.
   - Perform a gradient descent update on the generator using this loss.
   - Note: when training the generator, the discriminator's weights must be marked as non-trainable (we only want to update the generator's weights). In the classic workflow this is done by setting discriminator.trainable = False before training the generator (a short sketch of this trick follows the list).
4. Loop:
   - Repeat steps 2 and 3 many times to train the GAN.
5. Evaluate the GAN on a test set:
   - Use the trained generator to produce images and visualize them to assess the GAN's performance.
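Before moving on to the actual implementation, here is a hedged sketch of the weight-freezing trick mentioned in the note of step 3. It follows the classic two-model GAN recipe rather than this article's approach; the end-to-end example below never needs it, because it simply applies gradients only to the generator's weights. The combined model and the batch of noise are illustrative assumptions, and generator, discriminator, and latent_dim refer to the models defined right after this sketch.

```python
# Hedged sketch of the classic "freeze the discriminator" trick (not used in
# the end-to-end example below, which updates only the generator's weights).
discriminator.trainable = False  # freeze the discriminator for the generator update

# Combined model: latent vector -> generated image -> discriminator logit.
combined = keras.Sequential([generator, discriminator], name="combined")
combined.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
)

# Training the combined model on "these are real" labels only updates the generator.
noise = np.random.normal(size=(64, latent_dim)).astype("float32")
real_labels = np.ones((64, 1), dtype="float32")
combined.train_on_batch(noise, real_labels)

discriminator.trainable = True  # unfreeze before the next discriminator update
```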
Here are the network definitions:

```python
# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
```
Below is a feature-complete GAN class, which overrides compile() to use its own signature and implements the entire GAN algorithm in 17 lines of code inside train_step:
```python
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.built = True

    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = real_images.shape[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        real_images = torch.tensor(real_images, device=device)
        combined_images = torch.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = torch.concat(
            [
                torch.ones((batch_size, 1), device=device),
                torch.zeros((batch_size, 1), device=device),
            ],
            axis=0,
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * keras.random.uniform(labels.shape, seed=self.seed_generator)

        # Train the discriminator
        self.zero_grad()
        predictions = self.discriminator(combined_images)
        d_loss = self.loss_fn(labels, predictions)
        d_loss.backward()
        grads = [v.value.grad for v in self.discriminator.trainable_weights]
        with torch.no_grad():
            self.d_optimizer.apply(grads, self.discriminator.trainable_weights)

        # Sample random points in the latent space
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images"
        misleading_labels = torch.zeros((batch_size, 1), device=device)

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        self.zero_grad()
        predictions = self.discriminator(self.generator(random_latent_vectors))
        g_loss = self.loss_fn(misleading_labels, predictions)
        g_loss.backward()
        grads = [v.value.grad for v in self.generator.trainable_weights]
        with torch.no_grad():
            self.g_optimizer.apply(grads, self.generator.trainable_weights)

        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }
```
Let's train it on MNIST digits and look at the result:
```python
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))

# Create a TensorDataset
dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(all_digits), torch.from_numpy(all_digits)
)
# Create a DataLoader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

gan.fit(dataloader, epochs=1)
```
```
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 1582s 1s/step - d_loss: 0.3581 - g_loss: 2.0571

<keras.src.callbacks.history.History at 0x299ce1840>
```
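To close the loop on step 5 of the recipe above (visually evaluating the GAN), here is a hedged sketch of sampling the trained generator; matplotlib is an extra assumption that this article does not otherwise use:

```python
# Hedged sketch: sample the trained generator and plot a grid of digits.
import matplotlib.pyplot as plt  # assumed extra dependency

random_latent_vectors = keras.random.normal(shape=(16, latent_dim))
with torch.no_grad():
    generated_images = gan.generator(random_latent_vectors)
generated_images = generated_images.detach().cpu().numpy()

fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for image, ax in zip(generated_images, axes.flat):
    ax.imshow(image.squeeze(), cmap="gray")
    ax.axis("off")
plt.show()
```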
That concludes this detailed walkthrough of customizing the code that runs inside fit() with the PyTorch backend. For more on customizing fit() with PyTorch, please check out my other related articles!