Distributed Training with JAX and Flax
Learn about distributed training in JAX and Flax.
Training models on accelerators with JAX and Flax differs slightly from training on CPUs. For instance, when using multiple accelerators, the data needs to be replicated across the different devices. After that, we need to execute the training on multiple devices and aggregate the results. Flax supports both GPU and TPU accelerators.
This lesson will focus on training models with Flax and JAX using GPUs and TPUs.
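As a rough illustration of splitting data across devices, a leading device axis is usually added to each batch so that every accelerator receives its own slice. The snippet below is a minimal sketch, not part of the lesson's code; shard_batch is a hypothetical helper, and it assumes a NumPy batch whose size is divisible by the number of local devices:

import jax
import numpy as np

def shard_batch(batch):
  # Add a leading device axis so each local device gets an equal slice.
  n_devices = jax.local_device_count()
  return batch.reshape((n_devices, -1) + batch.shape[1:])

images = np.zeros((32, 28, 28, 3), dtype=np.float32)  # toy batch of 32 images
sharded = shard_batch(images)  # shape: (n_devices, 32 // n_devices, 28, 28, 3)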
Prior to training, it’s important to process the data and create a training state, both of which were covered in an earlier lesson.
Create training state
We now need to create parallel versions of our functions. Parallelization in JAX is done using the pmap function. pmap compiles a function with XLA and executes it on multiple devices.
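Before looking at the training state code, here is a minimal, self-contained sketch of pmap applied to a toy function. It is only an illustration and runs on however many XLA devices are visible, including a single one:

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
# One row per device; pmap maps the function over this leading axis in parallel.
xs = jnp.arange(n_devices * 2.0).reshape(n_devices, 2)
squared = jax.pmap(lambda x: x ** 2)(xs)
print(squared.shape)  # (n_devices, 2): one result slice per device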
from flax.training import train_state
import optax
import functools

@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, size_image, size_image, 3]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)
In the code above:
Lines 1–3: We import train_state from the flax.training module, along with the optax and functools libraries.
Line 5: We apply the @functools.partial() decorator with jax.pmap for parallel execution of the create_train_state() function. We set static_broadcasted_argnums=(1, 2) to broadcast learning_rate and momentum as static values.
Lines 6–11: We define the create_train_state() function that creates the initial state for model training. This function takes three arguments: rng is the random number generator key, while learning_rate and momentum are the parameters of the optimizer. Inside this function:
  Lines 8–9: We create an instance cnn of the CNN class and get the initial model parameters params by calling the init() method of cnn. This method takes the random number generator key and a dummy input image, a JAX array of ones.
  Line 10: We define a stochastic gradient descent optimizer with the provided learning rate and momentum.
  Line 11: We create and return the train state by calling the create() method of the train_state.TrainState class. This method takes the model's apply function apply_fn, the initialized parameters params, and the optimizer tx.
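As a hedged sketch of how the pmapped function might be called (the learning rate and momentum values here are only illustrative), we pass one PRNG key per device and plain Python scalars for the static arguments:

rng = jax.random.PRNGKey(0)
# One key per device; pmap maps over the leading axis of this array of keys.
device_rngs = jax.random.split(rng, jax.local_device_count())

# learning_rate and momentum are broadcast as static values (argnums 1 and 2).
state = create_train_state(device_rngs, 0.01, 0.9)

Because the function is pmapped, the returned state carries a leading device axis on its parameters, which is what subsequent pmapped training steps expect.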