Implementing and Training the Model with TensorFlow

Learn to implement and train the transformer model with TensorFlow.

We’ll now implement the model we just studied. First, let’s import a few things:

import tensorflow_hub as hub
import tensorflow as tf
import tensorflow.keras.backend as K

Implementing the ViT model

Next, we’re going to download the pretrained ViT model from TensorFlow Hub. We’ll be using a model submitted by Sayak Paul. You can see other ViT models here.

image_encoder = hub.KerasLayer("https://tfhub.dev/sayakpaul/vit_s16_fe/1", trainable=False)

We then define an input layer for the images and pass it to the image_encoder to get the final feature vector for each image:

image_input = tf.keras.layers.Input(shape=(224, 224, 3))
image_features = image_encoder(image_input)

We can look at the size of the final image representation by running:

print(f"Final representation shape: {image_features.shape}")

This will output:

Final representation shape: (None, 384)
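If you want to sanity-check the encoder on its own, one option is to wrap the layers above into a standalone Keras model. This is a minimal sketch rather than part of the original code, and the names image_feature_extractor and dummy_images are our own:

# Wrap the ViT encoder into a standalone model for inspection
image_feature_extractor = tf.keras.Model(inputs=image_input, outputs=image_features)
# Push a random batch of two images through the encoder;
# the output shape should be (2, 384)
dummy_images = tf.random.uniform((2, 224, 224, 3))
print(image_feature_extractor(dummy_images).shape)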

Next, we’ll look at the details of how to implement the text-based transformer model, which will take in the image representation to generate the image caption.

Implementing the text-based decoder

Here, we’ll implement a transformer decoder model from the ground up. This is different from how we used transformer models before, where we downloaded pretrained models and used them as they were.

Before we implement the model itself, we’re going to implement two custom Keras layers: one for the self-attention mechanism and another to capture the functionality of a single layer in the transformer model. Let’s start with the self-attention layer.

Defining the self-attention layer

Here, we define the self-attention layer using the Keras subclassing API:

class SelfAttentionLayer(tf.keras.layers.Layer):
    """ Defines the computations in the self-attention layer """

    def __init__(self, d):
        super(SelfAttentionLayer, self).__init__()
        # Feature dimensionality of the output
        self.d = d

    def build(self, input_shape):
        # Query weight matrix
        self.Wq = self.add_weight(
            shape=(input_shape[-1], self.d),
            initializer='glorot_uniform',
            trainable=True, dtype='float32'
        )
        # Key weight matrix
        self.Wk = self.add_weight(
            shape=(input_shape[-1], self.d),
            initializer='glorot_uniform',
            trainable=True, dtype='float32'
        )
        # Value weight matrix
        self.Wv = self.add_weight(
            shape=(input_shape[-1], self.d),
            initializer='glorot_uniform',
            trainable=True, dtype='float32'
        )

    def call(self, q_x, k_x, v_x, mask=None):
        q = tf.matmul(q_x, self.Wq)  # [None, t, d]
        k = tf.matmul(k_x, self.Wk)  # [None, t, d]
        v = tf.matmul(v_x, self.Wv)  # [None, t, d]
        # Computing the final output
        h = tf.keras.layers.Attention(causal=True)([
            q,  # query
            v,  # value
            k,  # key
        ], mask=[None, mask])
        # [None, t, t] . [None, t, d] => [None, t, d]
        return h

Implementing the self-attention layer

Here, we have to populate the logic for three functions:

__init__() and build(): Define various hyperparameters and layer initialization-specific logic.

call(): Computations that need to happen when the layer is called.

We define the dimensionality of the attention output, d, as an argument to the __init__() method. Next, in the build() method, we define three weight matrices: Wq, Wk, and Wv. These represent the weights of the query, key, and value, respectively.

Finally, the call() method holds the computation logic. It takes four inputs: the query, key, and value inputs, and an optional mask for the values. We then compute the latent q, k, and v by multiplying the inputs with the corresponding weight matrices Wq, Wk, and Wv. To compute attention, we’ll be using the out-of-the-box tf.keras.layers.Attention layer. This layer has several arguments; the one we care about here is setting causal=True.

By doing this, we’re instructing the layer to mask the tokens ...
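As a quick smoke test (not part of the original lesson), we can call the layer on a random batch of token representations and confirm that the output has the requested dimensionality d. The names self_attn and dummy_tokens below are our own, and d=384 is chosen to match the image encoder’s output size:

# Instantiate the layer with an output dimensionality of 384
self_attn = SelfAttentionLayer(d=384)
# A random batch of 2 sequences, each with 10 tokens of 512 features
dummy_tokens = tf.random.normal((2, 10, 512))
# Query, key, and value all come from the same input (self-attention);
# with causal=True, each token only attends to itself and earlier tokens
out = self_attn(dummy_tokens, dummy_tokens, dummy_tokens)
print(out.shape)  # (2, 10, 384)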
