Challenge: LSTM in JAX and Flax
Test your understanding of the LSTM model.
Problem statement
In the notebook below, we’re performing sarcasm analysis on a textual dataset stored in a file named multimodal_sarcasm.csv. Your job is to write code for the indicated code cells and functions. You are required to define the following:
- LSTM model class
- Train state creation function
- Training step function
- Evaluation step function
Try it yourself
Write your code in the Jupyter Notebook below: