Writing a Custom LSTM Cell in PyTorch
Implement an LSTM cell from scratch in PyTorch.
We'll cover the following
Creating an LSTM network in PyTorch is pretty straightforward.
import torch.nn as nn

# input_size -> N in the equations
# hidden_size -> H in the equations
layer = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
Note that the number of layers is the number of LSTM cells that are stacked and connected, so this network will have 2 LSTM cells connected together. We will see how this works in the next lesson. For now, we will focus on the single LSTM cell, based on its equations.
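To make the dimensions concrete, here is a small usage sketch. The sequence length of 5 and the batch size of 3 are made-up values for illustration; the rest follows the standard nn.LSTM interface.

import torch
import torch.nn as nn

layer = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

# By default, nn.LSTM expects input of shape (seq_len, batch, input_size)
x = torch.randn(5, 3, 10)   # 5 time steps, batch of 3, 10 features per step

output, (h_n, c_n) = layer(x)

print(output.shape)   # torch.Size([5, 3, 20]) -> hidden state of the top layer for every time step
print(h_n.shape)      # torch.Size([2, 3, 20]) -> final hidden state of each of the 2 layers
print(c_n.shape)      # torch.Size([2, 3, 20]) -> final cell state of each of the 2 layers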
The ideal way to understand an LSTM cell is to build it entirely from scratch. We have the equations for each gate, so all we have to do is translate them into code and connect them. As an example, a code template as well as the implementation of the input gate will be provided, and you will have to do the rest.
The originally proposed equations that we described are:
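These equations appear as a figure (captioned "Simplification of LSTM equations") in the original lesson and are not reproduced in this text. Based on the references to Equations (1) and (2) below, they are presumably the standard formulation with cell-state ("peephole") terms inside the gates, which can be written as:

i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + W_{ci} c_{t-1} + b_i)    (1)
f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + W_{cf} c_{t-1} + b_f)    (2)
c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)    (3)
o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + W_{co} c_t + b_o)    (4)
h_t = o_t \odot \tanh(c_t)    (5)

Here, x_t \in \mathbb{R}^N is the input at time step t, h_t, c_t \in \mathbb{R}^H are the hidden and cell states (N = input_size and H = hidden_size from the snippet above), \sigma is the sigmoid function, and \odot denotes element-wise multiplication.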
However, modern deep learning frameworks use a slightly simpler version of the LSTM. In particular, they disregard the previous cell state term c_{t-1} from Equations (1) and (2), and you will do the same. This results in a less complex model that is easier to optimize. Thus, we will implement the following simplified equations in this exercise:
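Written with no cell-state terms inside any of the gates (which is also how frameworks such as PyTorch implement the LSTM; note that the W_{co} c_t term from Equation (4) is dropped as well in this form), the simplified set is:

i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)
c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c)
o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
h_t = o_t \odot \tanh(c_t)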
If this exercise feels too difficult, don’t be discouraged; it is difficult. It may seem like a simple implementation of a few equations, but it is not. Feel free to give it a shot, but also feel free to move on if you get stuck.
Note that the code below will produce an error when executed for the first time. Don’t be alarmed by it; you can continue with the exercise as you normally would.
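For orientation only, here is one minimal way such a cell could be sketched, following the simplified equations above. This is not the course's template or solution; the class name CustomLSTMCell, the attribute names, and the sizes in the shape check are invented for this illustration.

import torch
import torch.nn as nn

class CustomLSTMCell(nn.Module):
    """A minimal LSTM cell implementing the simplified equations (no cell-state terms in the gates)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Input gate: i_t = sigmoid(W_xi x_t + W_hi h_{t-1} + b_i)
        self.W_xi = nn.Linear(input_size, hidden_size)
        self.W_hi = nn.Linear(hidden_size, hidden_size, bias=False)
        # Forget gate: f_t = sigmoid(W_xf x_t + W_hf h_{t-1} + b_f)
        self.W_xf = nn.Linear(input_size, hidden_size)
        self.W_hf = nn.Linear(hidden_size, hidden_size, bias=False)
        # Cell candidate: tanh(W_xc x_t + W_hc h_{t-1} + b_c)
        self.W_xc = nn.Linear(input_size, hidden_size)
        self.W_hc = nn.Linear(hidden_size, hidden_size, bias=False)
        # Output gate: o_t = sigmoid(W_xo x_t + W_ho h_{t-1} + b_o)
        self.W_xo = nn.Linear(input_size, hidden_size)
        self.W_ho = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x, state):
        h_prev, c_prev = state
        i = torch.sigmoid(self.W_xi(x) + self.W_hi(h_prev))
        f = torch.sigmoid(self.W_xf(x) + self.W_hf(h_prev))
        g = torch.tanh(self.W_xc(x) + self.W_hc(h_prev))
        o = torch.sigmoid(self.W_xo(x) + self.W_ho(h_prev))
        c = f * c_prev + i * g          # new cell state
        h = o * torch.tanh(c)           # new hidden state
        return h, c

# Quick shape check with made-up sizes:
cell = CustomLSTMCell(input_size=10, hidden_size=20)
x = torch.randn(3, 10)              # batch of 3, 10 input features
h0 = torch.zeros(3, 20)
c0 = torch.zeros(3, 20)
h1, c1 = cell(x, (h0, c0))
print(h1.shape, c1.shape)           # torch.Size([3, 20]) torch.Size([3, 20])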