How to use the Flax Linen API to build a convolutional neural network model and train it for image classification (using TensorFlow Datasets).

8bitmp3/JAX-Flax-Tutorial-Image-Classification-with-Linen

The annotated MNIST image classification example with Flax Linen and Optax

UPDATE: Use the up-to-date Flax Quickstart on the official Flax site.


Author: @8bitmp3

This tutorial uses Flax, a high-performance deep learning library for JAX designed for flexibility, to show you how to build a simple convolutional neural network (CNN) with the Linen API and Optax, and how to train that network for image classification on the MNIST dataset.

If you're new to JAX, check out:

To learn more about Flax and its Linen API, refer to:

This tutorial has the following workflow:

  • Perform a quick setup
  • Build a convolutional neural network model with the Linen API that classifies images
  • Define a loss and accuracy metrics function
  • Create a dataset function with TensorFlow Datasets
  • Define training and evaluation functions
  • Load the MNIST dataset
  • Initialize the parameters with PRNGs and instantiate the optimizer with Optax
  • Train the network and evaluate it

If you're using Google Colaboratory (Colab), enable GPU acceleration (Runtime > Change runtime type > Hardware accelerator: GPU).

Setup

  1. Install JAX, Flax, Optax, and TensorFlow Datasets (TFDS). Flax works with any data-loading pipeline; this example uses TFDS.
!pip install --upgrade -q pip jax jaxlib flax optax tensorflow-datasets
  1. Import JAX, JAX NumPy (which lets you run NumPy-style code on GPUs and TPUs), Flax (the Linen API and the train_state utilities), Optax, ordinary NumPy, and TFDS.
import jax
import jax.numpy as jnp               # JAX NumPy

from flax import linen as nn          # The Linen API
from flax.training import train_state
import optax                          # The Optax gradient processing and optimization library

import numpy as np                    # Ordinary NumPy
import tensorflow_datasets as tfds    # TFDS for MNIST

Build a model

Build a convolutional neural network with the Flax Linen API by subclassing flax.linen.Module. The architecture in this example is relatively simple because you're just stacking layers. You can therefore define the submodules inline, directly within the __call__ method, and wrap that method with the @compact decorator (flax.linen.compact).

class CNN(nn.Module):

  # nn.compact lets you define submodules inline within __call__
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1)) # Flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)    # There are 10 classes in MNIST
    return x
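The following sketch traces how a single 28x28x1 MNIST image flows through this model and how many parameters each layer holds. It is plain Python, so it runs without Flax. It assumes Flax's defaults: nn.Conv uses 'SAME' padding, which keeps the spatial size, and nn.avg_pool uses 'VALID' padding. The layer names follow Flax's default auto-naming (Conv_0, Dense_0, and so on). The helper cnn_shapes_and_params is hypothetical.

```python
# Hypothetical helper: activation shapes and parameter counts of the CNN above
# for one 28x28x1 image, assuming 'SAME' padding for nn.Conv and 'VALID'
# padding for nn.avg_pool (Flax's defaults).

def cnn_shapes_and_params(height=28, width=28, channels=1):
    shapes = []
    params = {}

    # Conv 3x3 with 32 features: kernel (3*3*in*out) plus one bias per feature
    params["Conv_0"] = 3 * 3 * channels * 32 + 32
    channels = 32
    shapes.append(("Conv_0", (height, width, channels)))

    # 2x2 average pooling with stride 2 halves each spatial dimension
    height, width = height // 2, width // 2
    shapes.append(("avg_pool", (height, width, channels)))

    # Conv 3x3 with 64 features
    params["Conv_1"] = 3 * 3 * channels * 64 + 64
    channels = 64
    shapes.append(("Conv_1", (height, width, channels)))

    height, width = height // 2, width // 2
    shapes.append(("avg_pool", (height, width, channels)))

    # Flatten, then two dense layers (weights plus biases)
    flat = height * width * channels
    shapes.append(("flatten", (flat,)))
    params["Dense_0"] = flat * 256 + 256
    shapes.append(("Dense_0", (256,)))
    params["Dense_1"] = 256 * 10 + 10
    shapes.append(("Dense_1", (10,)))
    return shapes, params


shapes, params = cnn_shapes_and_params()
for name, shape in shapes:
    print(f"{name:>9}: {shape}")
print("total parameters:", sum(params.values()))  # total parameters: 824458
```

Once you initialize the model later in this tutorial, you can compare these numbers against the real parameter shapes, for example with jax.tree_util.tree_map(jnp.shape, params).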

Create a metrics function

For loss and accuracy metrics, create a separate function:

  • Optax has a built-in softmax cross-entropy loss (optax.softmax_cross_entropy). You will also define and compute the loss inside the training step function later.
  • The labels can be one-hot encoded with jax.nn.one_hot, as demonstrated below.
def compute_metrics(logits, labels):
  loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10)))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }
  return metrics
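To make compute_metrics concrete, here is a NumPy-only sketch of the same computation. It is an illustration, not Optax's implementation. It one-hot encodes the labels, takes the mean softmax cross-entropy, and measures accuracy as the fraction of argmax predictions that match the labels. The helper names (one_hot, log_softmax, compute_metrics_np) are hypothetical.

```python
import numpy as np

def one_hot(labels, num_classes):
    # Row i has 1.0 at column labels[i] and 0.0 elsewhere
    return np.eye(num_classes, dtype=np.float64)[labels]

def log_softmax(logits):
    # Subtract the row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def compute_metrics_np(logits, labels, num_classes=10):
    # Per-example softmax cross-entropy: -sum(one_hot * log_softmax(logits))
    per_example = -(one_hot(labels, num_classes) * log_softmax(logits)).sum(axis=-1)
    loss = per_example.mean()
    accuracy = (logits.argmax(axis=-1) == labels).mean()
    return {"loss": float(loss), "accuracy": float(accuracy)}

# Two examples, three classes: the first is predicted correctly, the second is not
logits = np.array([[2.0, 0.0, 0.0],
                   [0.0, 3.0, 0.0]])
labels = np.array([0, 2])
print(compute_metrics_np(logits, labels, num_classes=3))
```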

The dataset

Define a function that:

  • Uses TFDS to load and prepare the MNIST dataset; and
  • Converts the images to floating-point numbers and scales the pixel values to the [0, 1] range.
def get_datasets():
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  # Load the full training and test splits as NumPy arrays (batch_size=-1)
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  # Convert images to floating-point and scale pixel values to [0, 1]
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
  return train_ds, test_ds
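The TFDS download can't run offline, so the sketch below builds a hypothetical stand-in dictionary, fake_ds, and applies the same conversion to it. It assumes the structure that tfds.as_numpy(..., batch_size=-1) returns for MNIST: uint8 images of shape (N, 28, 28, 1) and integer labels. NumPy's astype replaces jnp.float32 here, but the cast-then-divide logic is the same.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for the MNIST dict returned by TFDS with batch_size=-1
fake_ds = {
    "image": rng.integers(0, 256, size=(4, 28, 28, 1), dtype=np.uint8),
    "label": rng.integers(0, 10, size=(4,), dtype=np.int64),
}

# Same idea as get_datasets: cast to float32, then scale 0..255 to 0.0..1.0
fake_ds["image"] = fake_ds["image"].astype(np.float32) / 255.0

print(fake_ds["image"].dtype, fake_ds["image"].shape)
print(fake_ds["image"].min(), fake_ds["image"].max())
```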

Training and evaluation functions

  1. Write a training step function that:
  • Evaluates the neural network given the parameters and a batch of input images with the flax.linen.Module.apply method.
  • Defines a loss function (loss_fn) that computes the mean softmax cross-entropy loss and also returns the logits.
  • Evaluates the loss function and its gradient using jax.value_and_grad with has_aux=True, because loss_fn returns the logits as auxiliary output (check the JAX autodiff cookbook to learn more).
  • Applies the pytree of gradients with flax.training.train_state.TrainState.apply_gradients, which updates the model's parameters and the optimizer state.
  • Computes the metrics using compute_metrics (defined earlier) and returns them together with the updated training state.

Use JAX's @jit decorator to trace the entire train_step function and just-in-time (JIT) compile it with XLA into fused device operations that run faster and more efficiently on hardware accelerators.

@jax.jit
def train_step(state, batch):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = jnp.mean(optax.softmax_cross_entropy(
        logits=logits, 
        labels=jax.nn.one_hot(batch['label'], num_classes=10)))
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits, batch['label'])
  return state, metrics
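jax.value_and_grad returns the loss and its gradient in one call. In train_step that gradient is taken with respect to all of the CNN's parameters. The NumPy sketch below isolates only the last link of that chain: the gradient of the mean softmax cross-entropy with respect to the logits, which has the closed form (softmax(logits) - one_hot(labels)) / batch_size. It checks that form against a central finite difference. JAX itself uses automatic differentiation, not this formula, and the helper names are hypothetical.

```python
import numpy as np

def mean_xent(logits, one_hot_labels):
    # Mean softmax cross-entropy over the batch
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -(one_hot_labels * log_probs).sum(axis=-1).mean()

def value_and_grad_wrt_logits(logits, one_hot_labels):
    # Loss plus its analytic gradient with respect to the logits
    loss = mean_xent(logits, one_hot_labels)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    grad = (probs - one_hot_labels) / logits.shape[0]
    return loss, grad

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))
one_hot_labels = np.eye(10)[rng.integers(0, 10, size=4)]
loss, grad = value_and_grad_wrt_logits(logits, one_hot_labels)

# Central finite-difference estimate of one gradient entry
eps = 1e-6
up, down = logits.copy(), logits.copy()
up[1, 3] += eps
down[1, 3] -= eps
numeric = (mean_xent(up, one_hot_labels) - mean_xent(down, one_hot_labels)) / (2 * eps)
print(loss, grad[1, 3], numeric)
```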
  1. Create a jit-compiled function that evaluates the model on the test set using flax.linen.Module.apply:
@jax.jit
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits, batch['label'])
  1. Define a training function for one epoch that:
  • Shuffles the training data before each epoch using jax.random.permutation that takes a PRNGKey as a parameter (discussed in more detail later in this tutorial and in JAX - the sharp bits).
  • Runs an optimization step for each batch.
  • Retrieves the training metrics from the device with jax.device_get and averages each metric over all batches in the epoch.
  • Returns the training state with the updated parameters, along with the epoch's training loss and accuracy (training_epoch_metrics).
def train_epoch(state, train_ds, batch_size, epoch, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]  # Skip an incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))

  batch_metrics = []

  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  training_batch_metrics = jax.device_get(batch_metrics)
  training_epoch_metrics = {
      k: np.mean([metrics[k] for metrics in training_batch_metrics])
      for k in training_batch_metrics[0]}

  print('Training - epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))

  return state, training_epoch_metrics
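The shuffling logic in train_epoch can be tried in isolation. The NumPy sketch below swaps in numpy.random.Generator.permutation for jax.random.permutation and uses a tiny hypothetical dataset. It shows that the incomplete final batch is dropped, and that indexing every array with the same permutation keeps images and labels aligned. make_batches is a hypothetical helper.

```python
import numpy as np

def make_batches(num_examples, batch_size, seed):
    # Mirror train_epoch: shuffle indices, drop the incomplete final batch,
    # then reshape into (steps_per_epoch, batch_size)
    steps_per_epoch = num_examples // batch_size
    perms = np.random.default_rng(seed).permutation(num_examples)
    perms = perms[:steps_per_epoch * batch_size]
    return perms.reshape((steps_per_epoch, batch_size))

images = np.arange(10) * 100   # stand-in "images"
labels = np.arange(10)         # label i belongs to image i
perms = make_batches(len(images), batch_size=3, seed=0)
for perm in perms:
    # The same index array selects matching images and labels
    batch = {"image": images[perm], "label": labels[perm]}
    assert np.array_equal(batch["image"], batch["label"] * 100)
print(perms.shape)   # (3, 3): one of the 10 examples is skipped this epoch
```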
  1. Create a model evaluation function that:
  • Evaluates the model on the test set.
  • Retrieves the evaluation metrics from the device with jax.device_get.
  • Converts each metric in the resulting pytree from a 0-d array to a Python float with .item().
  • Returns the test loss and accuracy.
def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  eval_summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return eval_summary['loss'], eval_summary['accuracy']

Load the dataset

Download the dataset and preprocess it with the get_datasets function you defined earlier:

train_ds, test_ds = get_datasets()

Initialize the parameters with PRNGs and instantiate the optimizer

  1. PRNGs: Before you start training the model, you need to randomly initialize the parameters.

In NumPy, you would usually use a stateful pseudorandom number generator (PRNG).

JAX, however, uses an explicit PRNG (refer to JAX - the sharp bits for details):

  • Get one PRNGKey.
  • Split it to get a second key that you'll use for parameter initialization.

Note that in JAX and Flax you can keep separate PRNG keys for different purposes (such as rng and init_rng below), and Flax Modules can draw from separately named PRNG streams (for example, 'params' and 'dropout'). (Learn more about PRNG chains and JAX PRNG design.)

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
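For readers coming from NumPy, a loose analogy to splitting a key is numpy.random.SeedSequence.spawn. It derives independent child streams from one root seed, and re-spawning from the same seed reproduces them. This is an analogy only, not how JAX implements keys. One difference remains: a NumPy Generator advances its internal state on every call, whereas a JAX key is an immutable value, so reusing the same key returns the same numbers.

```python
import numpy as np

# NumPy analogy (not JAX): derive two independent child streams from one seed,
# loosely like `rng, init_rng = jax.random.split(rng)`
root = np.random.SeedSequence(0)
main_seq, init_seq = root.spawn(2)

init_draws = np.random.default_rng(init_seq).normal(size=3)
main_draws = np.random.default_rng(main_seq).normal(size=3)

# Re-spawning from the same root seed reproduces the same child stream
_, init_seq_again = np.random.SeedSequence(0).spawn(2)
assert np.array_equal(init_draws, np.random.default_rng(init_seq_again).normal(size=3))
print(init_draws, main_draws)
```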
  1. Instantiate the CNN model and initialize its parameters using a PRNG:
cnn = CNN()
params = cnn.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']
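  1. Instantiate an SGD optimizer with Nesterov momentum using Optax (optax.sgd takes the momentum coefficient through its momentum argument; nesterov is a boolean flag):
learning_rate = 0.001
momentum = 0.9
tx = optax.sgd(learning_rate=learning_rate, momentum=momentum, nesterov=True)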
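SGD with Nesterov momentum keeps a running trace of past gradients and applies a look-ahead step along it. The NumPy sketch below shows one common formulation. It assumes this matches what optax.sgd does with momentum and nesterov=True (a momentum trace followed by scaling with -learning_rate), so check the Optax documentation before relying on the exact form. The function sgd_nesterov_step and the toy quadratic objective are hypothetical and for illustration only.

```python
import numpy as np

def sgd_nesterov_step(params, grads, trace, learning_rate=0.001, momentum=0.9):
    # Assumed formulation of Nesterov momentum:
    #   new_trace = g + momentum * trace
    #   update    = g + momentum * new_trace   (look-ahead along the trace)
    #   params    = params - learning_rate * update
    new_trace = grads + momentum * trace
    update = grads + momentum * new_trace
    return params - learning_rate * update, new_trace

# Toy objective f(w) = 0.5 * ||w||^2, whose gradient is simply w
w = np.array([1.0, -2.0])
trace = np.zeros_like(w)
for _ in range(200):
    w, trace = sgd_nesterov_step(w, grads=w, trace=trace, learning_rate=0.1)
print(w)  # close to the minimum at [0, 0]
```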
  1. Create a TrainState, a dataclass that bundles the model's apply function, its parameters, and the optimizer, and whose apply_gradients method updates the parameters and the optimizer state:
state = train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

Train the network and evaluate it

  1. Set the default number of epochs and the size of each batch:
num_epochs = 10
batch_size = 32
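With these settings, the amount of work per run is easy to check. MNIST's training split has 60,000 images, and train_epoch only uses full batches. The short calculation below shows how many optimizer steps that means and how many examples are dropped per epoch.

```python
num_examples = 60_000   # MNIST training split size
batch_size = 32
num_epochs = 10

steps_per_epoch = num_examples // batch_size             # full batches only
dropped_per_epoch = num_examples - steps_per_epoch * batch_size
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, dropped_per_epoch, total_steps)   # 1875 0 18750
```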
  1. Finally, begin training and evaluating the model over 10 epochs:
  • For your training function (train_epoch), you need to pass a PRNG key that is used to permute the image data during shuffling. You already created a PRNG key when initializing the parameters of your network. You just need to split, or "fork", that PRNG state into two keys while maintaining the usual desirable PRNG properties. One becomes the new subkey (input_rng in this example) and the other replaces the previous key (rng). Use jax.random.split to do this. (Learn more about JAX PRNG design.)
  • Run one epoch of optimization steps over all training batches (train_epoch).
  • Evaluate on the test set after each training epoch (eval_model).
  • Retrieve the metrics from the device and print them.
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run one epoch of optimization steps over all training batches
  state, train_metrics = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print('Testing - epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))
Training - epoch: 1, loss: 1.7941, accuracy: 62.73
Testing - epoch: 1, loss: 0.93, accuracy: 82.31
Training - epoch: 2, loss: 0.6114, accuracy: 85.10
Testing - epoch: 2, loss: 0.44, accuracy: 88.47
Training - epoch: 3, loss: 0.4128, accuracy: 88.40
Testing - epoch: 3, loss: 0.36, accuracy: 89.89
Training - epoch: 4, loss: 0.3598, accuracy: 89.67
Testing - epoch: 4, loss: 0.32, accuracy: 90.81
Training - epoch: 5, loss: 0.3280, accuracy: 90.50
Testing - epoch: 5, loss: 0.30, accuracy: 91.54
Training - epoch: 6, loss: 0.3047, accuracy: 91.18
Testing - epoch: 6, loss: 0.28, accuracy: 91.94
Training - epoch: 7, loss: 0.2853, accuracy: 91.71
Testing - epoch: 7, loss: 0.26, accuracy: 92.26
Training - epoch: 8, loss: 0.2680, accuracy: 92.15
Testing - epoch: 8, loss: 0.24, accuracy: 92.90
Training - epoch: 9, loss: 0.2522, accuracy: 92.72
Testing - epoch: 9, loss: 0.23, accuracy: 93.15
Training - epoch: 10, loss: 0.2384, accuracy: 92.99
Testing - epoch: 10, loss: 0.22, accuracy: 93.56
