Create Model

Explore how to create image classification models with PyTorch by learning to instantiate models using the timm library. Understand key parameters such as pretrained weights, dropout rate, and class count to customize models for your dataset.

Introduction to model creation

Model creation is the first step in building an image classification model. In this lesson, we’ll learn how to create a new model based on our desired architecture.

Import

Place the following import statement at the top of the file to import PyTorch Image Models (timm).

import timm

Instantiate a new model

Models are instantiated with the built-in create_model(model_name) function. The model name is a required parameter, so we must pass it in. The returned object is a regular PyTorch module, so we can use it like any other model. To see how this works in practice, study the code in the following widget:

Python
import timm
model = timm.create_model('resnet50', pretrained=False, checkpoint_path="/app/resnet50_best.pth.tar", num_classes=4)
model.eval()
# get the number of features produced by the final feature layer (before the classifier)
print('The number of features in the model:', model.num_features)
# get the number of output classes of the classifier
print('The number of classes in the model:', model.num_classes)

Arguments

The create_model function accepts the following arguments:

  • model_name (str): This is the name of the model to instantiate.
  • pretrained (bool): If True, this loads pre-trained ImageNet-1k weights (default: False).
  • checkpoint_path (str): This is the path of a checkpoint to load after the model is initialized.

If we have a custom-trained model, we simply change the checkpoint_path value to our model’s corresponding path.

Additional keyword arguments

We can also pass model-specific keyword arguments. Most models share the following common arguments:

  • drop_rate (float): This is the dropout rate for training (default: 0.0).
  • global_pool (str): This is the global pool type (default: ‘avg’).

The resnet50 model also accepts the following optional parameters:

  • num_classes (int): This is the number of output classes of the classifier (default: 1000).
  • in_chans (int): This is the number of input channels of the images (default: 3).

Let’s look at the following example as a reference:

Python
import timm
model = timm.create_model(
    "resnet50",
    num_classes=4,
    in_chans=3,
    pretrained=False,
    checkpoint_path="/app/resnet50_best.pth.tar"
)
model.eval()
print('The drop rate for resnet50:', model.drop_rate)
print('The global pool for resnet50:', model.global_pool)

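Note: model.global_pool is an instance of the SelectAdaptivePool2d class, which applies 2D adaptive pooling over an input, so the last line prints the pooling module rather than a plain string.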

PyTorch Image Models (timm) provides a single function for model creation. We can use create_model(model_name) to instantiate a new model.