ResNet Model Definition
Learn how to define a ResNet model.
Training large neural networks can take days or weeks. Once these networks are trained, we can reuse their weights on new tasks, an approach known as transfer learning. As a result, we can fine-tune a pretrained network and get good results in a short time. Let’s look at how we can fine-tune a pretrained ResNet network in JAX and Flax.
Prior to using transfer learning and fine-tuning the ResNet model, it's important to process the data, which was covered previously.
Pretrained ResNet models are trained on many classes. However, the dataset we have contains two classes. Therefore, we use the ResNet as the backbone and define a custom classification layer.
Create a Head network
Create a Head network whose output matches the problem at hand, in this case, binary image classification.
```python
from flax import linen as nn
from functools import partial

class Head(nn.Module):
    '''head model'''
    batch_norm_cls: partial = partial(nn.BatchNorm, momentum=0.9)
    @nn.compact
    def __call__(self, inputs, train: bool):
        output_n = inputs.shape[-1]
        x = self.batch_norm_cls(use_running_average=not train)(inputs)
        x = nn.Dropout(rate=0.25)(x, deterministic=not train)
        x = nn.Dense(features=output_n)(x)
        x = nn.relu(x)
        x = self.batch_norm_cls(use_running_average=not train)(x)
        x = nn.Dropout(rate=0.5)(x, deterministic=not train)
        x = nn.Dense(features=config["NUM_LABELS"])(x)
        return x
```
In the code above, we import the `linen` module from the `flax` library as `nn` to define the neural network architecture, and `partial` from the `functools` library to create partial function applications. We define the `Head` class, which inherits from `nn.Module`, to represent the head model. Inside this class:
- Line 6: We call the `partial()` function and apply the `nn.BatchNorm` class with a fixed value of `momentum` to define a partial function, `batch_norm_cls`. We can use this partial function to create instances of the `nn.BatchNorm` layer with a specific `momentum` value.
- Lines 7–17: We define a `__call__()` function and apply the `@nn.compact` decorator to it. The `__call__()` function defines the model layers and the forward pass of the input. Inside this function:
  - Line 9: We calculate the number of output features from the last dimension of the `inputs` and store it in `output_n`.
  - Lines 10–13: We call the `batch_norm_cls()` function to apply the batch normalization layer to the `inputs`. We apply the `Dropout` layer with a `rate` of 25% to the output of the previous layer. We then apply the `Dense` layer with `output_n` features and a ReLU activation, respectively.
  - Lines 14–17: Similarly, we apply the batch normalization layer, the `Dropout` layer with a `rate` of 50%, and the `Dense` layer with `config["NUM_LABELS"]` features. Lastly, we return the output.
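The `config["NUM_LABELS"]` lookup assumes a `config` dictionary from the earlier data-processing step; since our dataset has two classes, a minimal stand-in is `{"NUM_LABELS": 2}`. To sanity-check the head in isolation, we can initialize and apply it on a dummy feature batch. This is a sketch, not part of the lesson's code; note that `Dropout` needs its own RNG stream when `train=True`:

```python
import jax
import jax.numpy as jnp

config = {"NUM_LABELS": 2}  # stand-in: two classes in our dataset

head = Head()
dummy_features = jnp.ones((1, 2048))  # e.g., pooled ResNet-50 features

# Separate RNG streams for parameter init and the Dropout layers.
variables = head.init(
    {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)},
    dummy_features,
    train=True,
)

# In evaluation mode, dropout is deterministic and batch norm uses its
# running statistics, so no RNG or mutable collections are needed.
logits = head.apply(variables, dummy_features, train=False)
print(logits.shape)  # (1, 2)
```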
Combine ResNet backbone with head
Combine the pretrained ResNet backbone with the custom head we created above.
```python
from jax_resnet import pretrained_resnet, slice_variables, Sequential
import jax.numpy as jnp

class Model(nn.Module):
    backbone: Sequential
    head: Head

    def __call__(self, inputs, train: bool):
        x = self.backbone(inputs)
        # average pool layer
        x = jnp.mean(x, axis=(1, 2))
        x = self.head(x, train)
        return x
```
In the code above:
- Lines 1–2: We import the required library modules: `pretrained_resnet`, `slice_variables`, and `Sequential` from the `jax_resnet` library, and the JAX version of NumPy as `jnp`.
- Lines 4–13: We define the `Model` class, which inherits from `nn.Module`. We define the `backbone` attribute of the `Sequential` type and the `head` attribute of the `Head` type. We define the `__call__()` function to apply the model to the given input. Inside this function:
  - Line 9: We apply the `backbone` model to the given input and store the output in the variable `x`.
  - Line 11: We call the `jnp.mean()` method to average the feature map over its spatial dimensions and update `x`.
  - Lines 12–13: We apply the `head` model to `x` and return the output.
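The `jnp.mean()` call on line 11 implements global average pooling: the backbone emits a feature map of shape `(batch, height, width, channels)`, and averaging over axes 1 and 2 collapses it to `(batch, channels)`, the flat feature vector the head expects. A quick shape check (the sizes here are illustrative):

```python
import jax.numpy as jnp

# A ResNet-50 backbone typically maps 224x224 images to a 7x7x2048 feature map.
feature_map = jnp.ones((8, 7, 7, 2048))
pooled = jnp.mean(feature_map, axis=(1, 2))
print(pooled.shape)  # (8, 2048)
```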
Load pretrained ResNet-50
Next, we create a function that loads the pretrained ResNet model. We omit the last two layers of the network because we have defined a custom head. The function returns the backbone model and its parameters, which are obtained using the `slice_variables()` function.
```python
def get_backbone_and_params(model_arch: str):

    if model_arch == 'resnet50':
        resnet_tmpl, params = pretrained_resnet(50)
        model = resnet_tmpl()
    else:
        raise NotImplementedError

    # get model & param structure for backbone
    start, end = 0, len(model.layers) - 2
    backbone = Sequential(model.layers[start:end])
    backbone_params = slice_variables(params, start, end)
    return backbone, backbone_params
```
We define the `get_backbone_and_params()` function to retrieve the backbone model and its related parameters. This function receives a string argument, `model_arch`, that specifies the model architecture to use. Inside this function:
- Lines 3–7: We check the value of the `model_arch` argument. If the value is `resnet50`, we call the `pretrained_resnet(50)` method to load the pretrained ResNet-50 template and its parameters, and instantiate the model by calling `resnet_tmpl()`. Otherwise, we raise a `NotImplementedError`.
- Lines 9–13: We drop the last two layers of the model, wrap the remaining layers in a `Sequential` module to form the `backbone`, call the `slice_variables()` function to extract the matching `backbone_params`, and return both.
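Putting the pieces together, here is a rough sketch of how the backbone, head, and pretrained parameters could be assembled. It assumes the parameter tree nests the submodules under the attribute names `backbone` and `head`, and that `slice_variables()` returns both the `params` and `batch_stats` collections; treat it as an illustration rather than the lesson's exact wiring:

```python
import jax
import jax.numpy as jnp
from flax.core import freeze, unfreeze

backbone, backbone_params = get_backbone_and_params('resnet50')
model = Model(backbone=backbone, head=Head())

# Initialize to create parameter structures for both submodules.
dummy_batch = jnp.ones((1, 224, 224, 3))
variables = model.init(
    {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)},
    dummy_batch,
    train=True,
)

# Swap the randomly initialized backbone weights for the pretrained ones.
variables = unfreeze(variables)
variables["params"]["backbone"] = backbone_params["params"]
variables["batch_stats"]["backbone"] = backbone_params["batch_stats"]
variables = freeze(variables)
```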