Challenge: Distributed Training with JAX and Flax
Test your understanding of image classification and distributed training.
We will perform distributed training using JAX and Flax in this challenge. We have imported all the necessary libraries for you.
Challenge 1: Load the dataset
In the /usr/local/notebooks directory, we have a dataset in a zipped folder, cars_and_bikes.zip, containing images from two classes: cars and bikes. There are twelve images per class, for a total of 24 images. Load the dataset using the image paths and labels (each image is named after the class it belongs to, followed by a serial number, e.g., bike.0.jpg or car.4.jpg). Then define a Dataset class and use it with a DataLoader to create the training and validation sets.
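A minimal sketch of this loading step follows, written in plain Python with only the standard library so it runs anywhere. It assumes the label encoding bike = 0 and car = 1 (the challenge does not fix one), and every name in it (extract_dataset, label_from_filename, collect_image_paths, CarsBikesDataset, train_val_split, iterate_batches) is hypothetical. The dataset class exposes __len__ and __getitem__, the map-style interface that a PyTorch torch.utils.data.DataLoader expects, assuming that is the DataLoader the notebook imports; iterate_batches is only a simplified stand-in for it. In practice, load_fn would read and resize the image file, for example with PIL.

```python
import os
import random
import zipfile
from typing import Callable, Iterator, List, Sequence, Tuple

# Assumed label encoding; the challenge only names the two classes.
CLASS_TO_INDEX = {"bike": 0, "car": 1}


def extract_dataset(zip_path: str, out_dir: str) -> str:
    """Unzip the archive (e.g. cars_and_bikes.zip) into out_dir."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(out_dir)
    return out_dir


def label_from_filename(path: str) -> int:
    """'bike.0.jpg' -> 0, 'car.4.jpg' -> 1: the class is the prefix before the first dot."""
    class_name = os.path.basename(path).split(".")[0]
    return CLASS_TO_INDEX[class_name]


def collect_image_paths(root: str) -> Tuple[List[str], List[int]]:
    """Walk root and return sorted .jpg paths together with their integer labels."""
    paths = []
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            # Skip hidden files such as macOS '._bike.0.jpg' metadata entries.
            if name.lower().endswith(".jpg") and not name.startswith("."):
                paths.append(os.path.join(dirpath, name))
    paths.sort()
    return paths, [label_from_filename(p) for p in paths]


class CarsBikesDataset:
    """Map-style dataset: supports len() and integer indexing."""

    def __init__(self, paths: Sequence[str], labels: Sequence[int],
                 load_fn: Callable[[str], object]):
        self.paths = list(paths)
        self.labels = list(labels)
        self.load_fn = load_fn  # hypothetical image reader, e.g. open + resize

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int):
        return self.load_fn(self.paths[idx]), self.labels[idx]


def train_val_split(paths, labels, val_fraction=0.25, seed=0):
    """Shuffle indices with a fixed seed and hold out val_fraction for validation."""
    indices = list(range(len(paths)))
    random.Random(seed).shuffle(indices)
    n_val = int(len(indices) * val_fraction)

    def pick(idx):
        return [paths[i] for i in idx], [labels[i] for i in idx]

    return pick(indices[n_val:]), pick(indices[:n_val])


def iterate_batches(dataset, batch_size: int, shuffle: bool = True,
                    seed: int = 0) -> Iterator[Tuple[list, list]]:
    """Simplified DataLoader stand-in: yields (images, labels) lists per batch."""
    order = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        images, labels = zip(*(dataset[i] for i in order[start:start + batch_size]))
        yield list(images), list(labels)
```

With 24 images and val_fraction=0.25, the split yields 18 training and 6 validation images. Because the shuffle is not stratified, the class balance of such a small validation set can drift, so a per-class split may be preferable here.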