SimCLR Training Objective
Get introduced to SimCLR’s network architecture and its loss function.
Now that we have two augmented versions of the input batch, we can pass them through the SimCLR network and define the loss used to train it.
Network architecture
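As shown in the figure below, the two augmented versions of an image are passed through the same network, a feature backbone followed by a projection head, to produce two feature embeddings.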
The code example below implements the class `SimCLR_Network`, which passes the input image through a `resnet18` backbone followed by an MLP projection head.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision
import torchvision.models as models
from utils import Augment  # course-provided augmentation helper
from PIL import Image

class SimCLR_Network(nn.Module):
    def __init__(self, embed_dim=512):
        super(SimCLR_Network, self).__init__()
        self.backbone = models.resnet18()  # resnet18 backbone
        in_features = self.backbone.fc.in_features  # 512 for resnet18
        self.backbone.fc = nn.Identity()  # remove the fc layer of resnet18

        # add mlp projection head
        self.projection = nn.Sequential(
            nn.Linear(in_features, embed_dim),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(embed_dim),
        )

    def forward(self, x):
        f = self.backbone(x)  # (N, 512) backbone features
        return self.projection(f)  # (N, embed_dim) embeddings

network = SimCLR_Network()

# load two example images as RGB and resize them to 224x224
batch = [
    Image.open("n02107683_Bernese_mountain_dog.jpeg").convert("RGB"),
    Image.open("cat.jpg").convert("RGB"),
]
batch = [T.functional.to_tensor(img.resize((224, 224))) for img in batch]
batch = torch.stack(batch)  # (2, 3, 224, 224)
torchvision.utils.save_image(batch, "./output/image.png", normalize=True)

augment = Augment(img_size=224)
aug1, aug2 = augment(batch), augment(batch)  # generate two augmented versions of batch
torchvision.utils.save_image(aug1, "./output/t1_image.png", normalize=True)
torchvision.utils.save_image(aug2, "./output/t2_image.png", normalize=True)

z1, z2 = network(aug1), network(aug2)  # feature embeddings
print("Shape of z1 and z2 is", z1.shape, "and", z2.shape)
```
- Line 10: We implement the class `SimCLR_Network`, which passes the input image to a `resnet18` backbone and an MLP projection head.
- Line 13: We define the feature backbone `self.backbone` as a `resnet18` network.
- Lines 14–15: We record the input size of the fully connected classification layer of `resnet18` and then remove that layer by replacing it with an `nn.Identity()` layer. As a result, `self.backbone` takes an image and returns a 512-dimensional feature vector instead of class scores.
- Lines 18–24: We define the projection head `self.projection` as an MLP built from `nn.Linear`, `nn.ReLU`, and `nn.BatchNorm1d` layers. This projection head takes the 512-dimensional `resnet18` features from `self.backbone` and ...