LSTM Model
Learn how to build and train LSTM neural networks with JAX and Flax.
Define LSTM model in Flax
We are now ready to define the LSTM model in Flax. To build LSTMs in Flax, we use either `LSTMCell` or `OptimizedLSTMCell`; `OptimizedLSTMCell` is a more efficient implementation of the same cell.
The `LSTMCell.initialize_carry` function is used to initialize the hidden state (carry) of the LSTM cell. It expects:
- A pseudorandom number generator (PRNG) key.
- The batch dimensions.
- The number of units (the `size` argument).
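Conceptually, the initialized carry is just a pair of zero arrays, one for the cell state and one for the hidden state, each of shape `batch_dims + (size,)`. Here is a minimal sketch of the equivalent result in plain `jax.numpy` (the PRNG key only matters for non-zero carry initializers; the values `(32,)` and `128` are example choices):

```python
import jax.numpy as jnp

batch_dims, size = (32,), 128  # example values

# An LSTM carry is a (cell_state, hidden_state) pair; the default
# initializer fills both with zeros of shape batch_dims + (size,).
carry = (jnp.zeros(batch_dims + (size,)),
         jnp.zeros(batch_dims + (size,)))
```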
Let’s use the `setup` method to define the LSTM model. The LSTM contains the following layers:
- An Embedding layer with the same number of features and length as defined in the vectorization layer.
- LSTM layers that pass data in one direction, as specified by the `reverse` argument.
- A couple of Dense layers.
- A final Dense output layer.
```python
import jax
from flax import linen as nn
class LSTMModel(nn.Module):
    def setup(self):
        self.embedding = nn.Embed(max_features, max_len)
        lstm_layer = nn.scan(nn.OptimizedLSTMCell,
                             variable_broadcast="params",
                             split_rngs={"params": False},
                             in_axes=1,
                             out_axes=1,
                             length=max_len,
                             reverse=False)
        self.lstm1 = lstm_layer()
        self.dense1 = nn.Dense(256)
        self.lstm2 = lstm_layer()
        self.dense2 = nn.Dense(128)
        self.lstm3 = lstm_layer()
        self.dense3 = nn.Dense(64)
        self.dense4 = nn.Dense(2)

    @nn.remat
    def __call__(self, x_batch):
        x = self.embedding(x_batch)

        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=128)
        (carry, hidden), x = self.lstm1((carry, hidden), x)
        x = self.dense1(x)
        x = nn.relu(x)

        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=64)
        (carry, hidden), x = self.lstm2((carry, hidden), x)
        x = self.dense2(x)
        x = nn.relu(x)

        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=32)
        (carry, hidden), x = self.lstm3((carry, hidden), x)
        x = self.dense3(x)
        x = nn.relu(x)

        x = self.dense4(x[:, -1])
        return nn.log_softmax(x)
```
We import the `linen` module from the `flax` library as `nn` to define the LSTM model. In the code above, we define the `LSTMModel` class using `nn.Module`. Inside this class:

- Lines 4–19: We define the `setup()` function, which declares the layers and components of the model. Inside this function:
  - Line 5: We define the Embedding layer to map the discrete input tokens to continuous vectors.
  - Lines 6–12: We define the LSTM layer as `lstm_layer`, where we call the `scan()` method of `nn` to set up the LSTM layer with the given configurations.
  - Lines 13–14: We call the defined `lstm_layer()` to create the first LSTM layer, followed by a Dense layer with `256` units.
  - Lines 15–19: Similarly, we define another LSTM layer, a Dense layer with `128` units, the third LSTM layer, and a Dense layer with `64` units. Lastly, we define a Dense layer with two units for the binary classification.
- Line 21: We apply the `@nn.remat` decorator to the `__call__()` function for memory optimization. The `@nn.remat` decorator saves memory when using LSTMs on long sequences by recomputing intermediate activations during the backward pass instead of storing them.
- Lines 22–41: We define the `__call__()` function to implement the forward pass. Inside this function:
  - Line 23: We pass the given input through the first layer (the Embedding layer) of the model.
  - Lines 25–28: We call the `nn.OptimizedLSTMCell.initialize_carry` method to initialize the `carry` and `hidden` states, pass the output of the previous layer through the first LSTM layer, followed by the first Dense layer, and apply the ReLU activation function.
  - Lines 30–38: Similarly, we pass the output of the previous layer through the subsequent LSTM and Dense layers and apply the ReLU activation function.
  - Lines 40–41: We take the last timestep of the previous layer's output and pass it through the final Dense layer. Lastly, we apply the LogSoftmax activation and return the output.
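The memory-saving behavior of `@nn.remat` comes from JAX's rematerialization transform, available in plain JAX as `jax.checkpoint` (also exported as `jax.remat`). A minimal sketch of the idea, using a toy function `f` that is illustrative and not part of the model:

```python
import jax
import jax.numpy as jnp

@jax.checkpoint
def f(x):
    # Intermediate values inside f are recomputed during the
    # backward pass instead of being stored, trading compute
    # for a smaller memory footprint.
    return jnp.sin(jnp.sin(x))

# Gradients are mathematically identical to the un-checkpointed version.
g = jax.grad(f)(1.0)
```

The trade-off matters most for long sequences, where storing every timestep's activations for backpropagation can dominate memory usage.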
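Since the model returns log-probabilities over the two classes, class predictions can be recovered with an argmax. Here is a small, self-contained sketch using `jax.nn.log_softmax`, the same transform `nn.log_softmax` applies; the logits below are dummy values standing in for the output of the final Dense layer:

```python
import jax.numpy as jnp
from jax.nn import log_softmax

# Dummy outputs of the final two-unit Dense layer for a batch of two samples.
logits = jnp.array([[2.0, -1.0],
                    [0.5, 3.0]])

log_probs = log_softmax(logits, axis=-1)  # log-probabilities per class
preds = jnp.argmax(log_probs, axis=-1)    # predicted class per sample
```

Exponentiating the log-probabilities recovers a proper probability distribution over the two classes for each sample.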
We apply the ...