Create Model
Explore how to create image classification models with PyTorch by learning to instantiate models using the timm library. Understand key parameters such as pretrained weights, dropout rate, and class count to customize models for your dataset.
Introduction to model creation
Model creation is the first step in building an image classification model. In this lesson, we’ll learn how to create a new model based on our desired architecture.
Import
Place the following import statement at the top of our file to import the PyTorch Image Models (timm) library.
import timm
Instantiate a new model
Models are instantiated with timm's built-in create_model(model_name) function. The model name is a required argument, so we always pass it in. The returned object is a standard PyTorch nn.Module, so we can train it and run inference with it like any other PyTorch model.
Arguments
The create_model function accepts the following arguments:
model_name (str): The name of the model to instantiate, for example resnet50.
pretrained (bool): If True, loads pretrained weights (ImageNet-1k weights for resnet50 and many other models).
checkpoint_path (str): The path of a checkpoint whose weights are loaded into the model after it is initialized.
If we have a custom-trained model, we set checkpoint_path to the path of its saved weights. The other arguments, such as the model name and class count, must match the architecture the checkpoint was saved from.
Additional keyword arguments
We can also pass model-specific keyword arguments, which create_model forwards to the model's constructor. Most timm models accept the following common arguments:
drop_rate (float): The dropout rate used during training (default: 0.0).
global_pool (str): The global pooling type (default: 'avg').
Two further arguments let us adapt a model such as resnet50 to our own dataset:
num_classes (int): The number of output classes (default: 1000).
in_chans (int): The number of input image channels, for example 3 for RGB or 1 for grayscale (default: 3).
Note: The global pooling layer selected by global_pool is implemented by timm's SelectAdaptivePool2d class, which applies 2D adaptive pooling over an input.
The PyTorch Image Models (timm) library provides a single function for model creation: create_model(model_name) instantiates a new model, and its keyword arguments customize it for our dataset.