Quick Look at RNN, LSTM, GRU and Attention
Much of the success in the field of Artificial Intelligence during the early 2010s can be attributed to Convolutional Neural Networks (CNNs), helped by advances in hardware such as GPUs and in libraries and frameworks such as TensorFlow and PyTorch. Although CNNs are efficient at Computer Vision tasks such as image classification, object localization and segmentation, they were not as successful when it came to NLP tasks. The main reason is that an image is a fixed-size grid of pixels, whereas natural language is a variable-length sequence in which the meaning of a word depends on the words around it. Because of this sequential nature, NLP required an entirely different approach. On top of that, many English words have multiple meanings, a phenomenon called lexical ambiguity.
Also, I would like to quote Stephen Clark, a full-time research scientist at DeepMind:
“Ambiguity is the greatest bottleneck to computational knowledge acquisition, the killer problem of all natural language processing.”
The answer to this problem started with Recurrent Neural Networks, fondly called RNNs, which are great at modeling sequential data. If I were to draw a modern NLP timeline, I would place RNNs first, followed by LSTMs/GRUs and then the Transformer variants. The introduction of the Transformer architecture in the famous “Attention Is All You Need” paper was the ImageNet moment for NLP. I would strongly recommend learning RNNs, LSTMs, GRUs and attention, along with their capabilities and limitations, before venturing into Transformers. In this post we will do just that: we will explore RNNs, LSTMs/GRUs and the attention mechanism, and possibly cover Transformers in my next post. I’m new to writing, and this is my first shot at blogging. Just saying :)
Recurrent Neural Networks (RNNs)
As mentioned before, RNNs are good at modelling sequential data such as text, audio and time series. If you have worked with fully connected layers, you can imagine the simplest fully connected network, as shown below, with one input layer, one hidden layer and one output layer.
An RNN is similar to a fully connected neural network with a feedback loop added. The output of the first step is fed back as input to the next step, and this is what gives an RNN its memory.
An RNN is basically the above network repeated n times for n time steps. At every time step, it processes the new input in the context of the outputs of the prior steps. Although RNNs have memory, due to their inherent nature they are not good at retaining it for long: they suffer from short-term memory. You can probably draw parallels to the famous character Dory in Pixar’s movie “Finding Dory” :). In addition to the problem of short-term memory, they are also slow to train, because each step depends on the previous one, so the computation cannot be parallelized across time steps. As you can see from the example below, RNNs cannot capture context that is far away.
Coreference resolution: In the image above, what you see is a coreference resolution problem. Coreference resolution is the task of working out which earlier word or phrase a pronoun such as “it”, “his”, “her”, “he”, “she” or “them” refers to. It cannot be solved unless the context is retained. It is trivial for a human to understand that “his” refers to “Tom”, but it can be an uphill task for a machine.
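A minimal NumPy sketch of the recurrence described above may help. It assumes the common tanh update h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h); the function names, sizes and random weights are illustrative assumptions, not taken from any library. The second function hints at where the short-term memory comes from: the derivative of a later hidden state with respect to an earlier one is a product of per-step factors, and with the (assumed) recurrent weights used here it shrinks quickly with distance.

```python
import numpy as np

def rnn_forward(xs, W_xh, W_hh, b_h):
    """Run a vanilla RNN over xs (shape (T, input_dim)); return all hidden states (T, hidden_dim)."""
    h = np.zeros(W_hh.shape[0])  # h_0: no context before the first step
    hs = []
    for x_t in xs:
        # The previous hidden state is fed back in: this feedback loop is the "memory".
        h = np.tanh(W_xh @ x_t + W_hh @ h + b_h)
        hs.append(h)
    return np.stack(hs)

def jacobian_norms(hs, W_hh):
    """norms[k] is the spectral norm of d h_{k+1} / d h_1, built up with the chain rule.

    Each step contributes d h_t / d h_{t-1} = diag(1 - h_t**2) @ W_hh (tanh derivative times W_hh).
    """
    J = np.eye(W_hh.shape[0])
    norms = [np.linalg.norm(J, 2)]
    for h_t in hs[1:]:
        J = (1.0 - h_t**2)[:, None] * W_hh @ J
        norms.append(np.linalg.norm(J, 2))
    return np.array(norms)

rng = np.random.default_rng(0)
input_dim, hidden_dim, T = 4, 8, 20
W_xh = rng.normal(scale=0.5, size=(hidden_dim, input_dim))
Q, _ = np.linalg.qr(rng.normal(size=(hidden_dim, hidden_dim)))
W_hh = 0.5 * Q  # assumption for the demo: recurrent weights with spectral norm 0.5
b_h = np.zeros(hidden_dim)
xs = rng.normal(size=(T, input_dim))

hs = rnn_forward(xs, W_xh, W_hh, b_h)
norms = jacobian_norms(hs, W_hh)
print(hs.shape)           # (20, 8)
print(norms[[0, 5, 19]])  # shrinks roughly geometrically with distance
```

The 0.5 scale on the recurrent matrix was picked to make the effect easy to see; with larger weights the same product can instead blow up (exploding gradients), so treat this as an illustration rather than a proof.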
Long Short Term Memory (LSTM)
Before we move on to LSTMs, I want to quickly explain the activation functions used inside LSTM cells. Activation functions add non-linearity to neural networks; without them, a stack of layers collapses into a single linear transformation and can only learn linear relationships. Tanh and sigmoid are the two activation functions used inside an LSTM. Tanh squashes any incoming value to between -1 and 1, whereas sigmoid squashes it to between 0 and 1. Because tanh outputs can be negative or positive, tanh is used to compute the contribution of a particular word to the context; because sigmoid outputs lie between 0 and 1, sigmoid acts as a filter that decides whether to forget (close to 0) or keep (close to 1) a particular word in the context. We will understand these in more detail in the following section.
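A tiny NumPy illustration of the two ranges, and of how a sigmoid output behaves like a soft keep/forget switch when it is multiplied element-wise with a vector. The input values are arbitrary, and `sigmoid` is written by hand here since NumPy does not ship one.

```python
import numpy as np

def sigmoid(z):
    # Squashes any real number into (0, 1), which is why it works as a gate.
    return 1.0 / (1.0 + np.exp(-z))

z = np.array([-10.0, -1.0, 0.0, 1.0, 10.0])
s = sigmoid(z)
t = np.tanh(z)
print(np.round(s, 3))  # all values between 0 and 1
print(np.round(t, 3))  # all values between -1 and 1

# Gating: element-wise multiplication by a sigmoid output scales each entry
# of a vector by how much of it should be kept.
memory = np.array([0.8, -0.5, 0.3])
gate = sigmoid(np.array([10.0, -10.0, 0.0]))  # roughly: keep, forget, keep half
print(np.round(gate * memory, 3))  # first entry kept, second dropped, third halved
```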
An LSTM has a similar control flow to an RNN; the key difference lies in the operations carried out within the LSTM cell. The most significant part of an LSTM is the cell state, the memory that runs as a horizontal line across the top of the cell and carries the context. The LSTM can forget, update and add context with the help of the following 3 operations within the cell.
- Forget Gate: Forget information from the previous steps that is no longer needed and carry forward only the long-term information that is required. This is achieved by a sigmoid operation on the previous step’s hidden output h at time step t-1 and the new input x at time step t. The sigmoid acts as a filter: values close to 0 drop information from the memory and values close to 1 keep it.
- Input Gate: The input gate decides what information needs to be added to the memory based on the new input at time step t. Here too, a sigmoid filters out unwanted information, while a tanh produces the candidate values, the key information that needs to be added to the context.
- Update and Output Gate: (I prefer to combine both operations into one logical step for easier understanding; some practitioners prefer to treat them as two separate operations.) The memory is updated using the results of the previous gate operations, and it is then time to produce the output. The updated memory is carried over to the next step as the context vector, and a sigmoid filter decides which parts of it are sent out as the hidden vector output from time step t.
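The three operations can be written as one LSTM time step. This is a NumPy sketch of the standard formulation (no peephole connections), with the four weight blocks stacked into one matrix `W`. That stacking order, the function name `lstm_step` and the random toy weights are assumptions for illustration; real libraries may order the gates differently. The lines computing `c_t` and `h_t` correspond to the combined update-and-output step.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM time step.

    W has shape (4 * hidden, hidden + input) and b has shape (4 * hidden,);
    the four row blocks hold the forget, input, candidate and output weights.
    """
    hidden = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x_t]) + b
    f = sigmoid(z[0 * hidden:1 * hidden])  # forget gate: what to drop from c_prev
    i = sigmoid(z[1 * hidden:2 * hidden])  # input gate: how much new info to write
    g = np.tanh(z[2 * hidden:3 * hidden])  # candidate values in (-1, 1)
    o = sigmoid(z[3 * hidden:4 * hidden])  # output gate: what to expose as h_t
    c_t = f * c_prev + i * g               # update the cell state (the "top line")
    h_t = o * np.tanh(c_t)                 # filtered view of the memory
    return h_t, c_t

rng = np.random.default_rng(0)
input_dim, hidden = 3, 5
W = rng.normal(scale=0.3, size=(4 * hidden, hidden + input_dim))
b = np.zeros(4 * hidden)
h, c = np.zeros(hidden), np.zeros(hidden)
for x_t in rng.normal(size=(6, input_dim)):  # a 6-step toy sequence
    h, c = lstm_step(x_t, h, c, W, b)
print(h.shape, c.shape)  # (5,) (5,)
```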
If you consider the example below, continuing from the RNN one, an LSTM is able to retain context better and resolve the reference of “his” to Tom. If you treat each word of the sentence as the input at one time step, the LSTM performs all of the gate operations described above at every time step and is able to carry the context, or memory, forward.
Gated Recurrent Unit (GRU)
GRU is very similar to LSTM, with only minor variations. It is growing in popularity and can be considered an alternative to LSTMs.
- GRU combines the forget and input gate of LSTM into an Update Gate.
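- It also merges the cell state and the hidden state.
- It uses a Reset Gate to decide how much of the old state at time step t-1 to combine with the new input at time step t when computing the new candidate memory; the update gate then blends the old state with this candidate to produce the final output.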
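For comparison, a minimal NumPy sketch of one GRU step under the usual formulation. The function name and random weights are illustrative, and conventions differ: some references and libraries swap the roles of z and 1 - z in the final blend.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, Wz, Wr, Wh, bz, br, bh):
    """One GRU time step. Each W* has shape (hidden, hidden + input)."""
    hx = np.concatenate([h_prev, x_t])
    z = sigmoid(Wz @ hx + bz)  # update gate: plays the role of LSTM's forget + input gates
    r = sigmoid(Wr @ hx + br)  # reset gate: how much of h_prev to use for the candidate
    h_cand = np.tanh(Wh @ np.concatenate([r * h_prev, x_t]) + bh)
    # No separate cell state: the hidden state itself is the memory.
    return (1.0 - z) * h_prev + z * h_cand

rng = np.random.default_rng(1)
input_dim, hidden = 3, 4
Wz, Wr, Wh = (rng.normal(scale=0.3, size=(hidden, hidden + input_dim)) for _ in range(3))
bz, br, bh = np.zeros(hidden), np.zeros(hidden), np.zeros(hidden)
h = np.zeros(hidden)
for x_t in rng.normal(size=(5, input_dim)):
    h = gru_step(x_t, h, Wz, Wr, Wh, bz, br, bh)
print(h.shape)  # (4,)
```

There are three weight matrices here versus four in an LSTM, which is one reason GRUs are slightly cheaper to run.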
As it turns out, neither LSTM nor GRU is great at retaining context when that context is far away. Their gates reduce, but do not eliminate, the vanishing gradient problem, so they still struggle with long sequences. The NLP world needed something more, and that gave birth to the concept of “attention”. We will look at attention shortly, after looking at RNN-based encoders and decoders.
RNN Based Encoder And Decoder
An encoder and a decoder are each built by placing RNN cells one after the other. Several types of encoder-decoder setups are possible, depending on the task we want to solve.
- Vector-to-Sequence Models: commonly used in language generation tasks.
- Sequence-to-Vector Models: commonly used in sentiment analysis problems.
- Sequence-to-Sequence Models: commonly seen in language translation applications.
All of the above models work using the following two core ideas:
- The encoder has an input sequence x1, x2, x3 and encoder states c1, c2, c3. It outputs a single vector c, which is passed as input to the decoder.
- Like encoders, decoders are also built from RNNs, with decoder states s1, s2, s3 and outputs denoted by y1, y2, … Because the entire input has to be represented as this one single vector, some information is lost.
One major shortfall of this architecture was its inability to capture all of the information in a longer sentence using this single vector. Again, the NLP world needed more to solve this problem.
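The bottleneck is easy to see in code. Below is a toy NumPy encoder-decoder, assuming plain tanh RNN cells and a decoder that feeds its own raw output scores back in as the next input (a real translation model would feed back an embedding of the chosen word, and would be trained). All names and sizes are hypothetical. Whatever the input length, the decoder only ever receives the fixed-size vector `c`.

```python
import numpy as np

def rnn_step(x_t, h, W_x, W_h):
    return np.tanh(W_x @ x_t + W_h @ h)

def encode(xs, W_x, W_h):
    """Run the encoder RNN and keep only its final state as the vector c."""
    h = np.zeros(W_h.shape[0])
    for x_t in xs:
        h = rnn_step(x_t, h, W_x, W_h)
    return h  # the single fixed-size vector handed to the decoder

def decode(c, n_steps, W_y, W_s, W_out):
    """Hypothetical decoder: starts from c and feeds back its own outputs."""
    s = c                          # the decoder's initial state is the encoder's summary
    y = np.zeros(W_out.shape[0])   # no previous output before the first step
    outputs = []
    for _ in range(n_steps):
        s = rnn_step(y, s, W_y, W_s)
        y = W_out @ s              # e.g. scores over a toy vocabulary
        outputs.append(y)
    return np.stack(outputs)

rng = np.random.default_rng(2)
d_in, hidden, vocab = 3, 6, 10
W_x = rng.normal(scale=0.4, size=(hidden, d_in))
W_h = rng.normal(scale=0.4, size=(hidden, hidden))
W_y = rng.normal(scale=0.4, size=(hidden, vocab))
W_s = rng.normal(scale=0.4, size=(hidden, hidden))
W_out = rng.normal(scale=0.4, size=(vocab, hidden))

short_c = encode(rng.normal(size=(4, d_in)), W_x, W_h)
long_c = encode(rng.normal(size=(400, d_in)), W_x, W_h)
print(short_c.shape, long_c.shape)  # (6,) (6,): same size however long the input
ys = decode(long_c, 5, W_y, W_s, W_out)
print(ys.shape)  # (5, 10)
```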
How was the problem solved?
Attention was introduced into RNN and LSTM encoder-decoders to solve it.
Attention with Encoder and Decoder
As you can see in the above picture, an attention mechanism is added between the encoder and decoder blocks. Its function is to take the hidden h vectors from the encoder and the state s vectors of the decoder and generate context c vectors. The inputs h and s are passed through a fully connected layer followed by a softmax to calculate attention weights a1, a2 and a3, and the context vector is the sum of the h vectors weighted by these attention weights.
Assume a language translation task in which your model needs to translate from one language to another. To generate every word in the destination language, the model performs the following steps:
- The attention box takes in the h vectors of the entire input sequence of the source language.
- It also takes in the s vector of the most recent output generated in the destination language. For the first output this is a zero vector, since nothing has been generated before it; when generating the second output, it takes the s vector from the first output into account, and so on.
- The h and s vectors are passed through a fully connected layer.
- The result is fed to a softmax to generate the attention weights a1, a2, a3.
- These weights determine how much attention each word in the input receives: the heavier the attention, the more focus on that word, and the lighter the attention, the less focus. The context vector for this step is the weighted sum of the h vectors.
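The steps above map onto a few lines of NumPy. This sketch uses additive (Bahdanau-style) scoring, one common way to implement the "fully connected layer followed by softmax" described here. The weight names `W_h`, `W_s`, `v` and the random values are assumptions for illustration; in a real model they are learned together with the encoder and decoder.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e / e.sum()

def additive_attention(H, s, W_h, W_s, v):
    """Additive attention for one decoder step.

    H: encoder hidden vectors, shape (T, hidden), one row per source word.
    s: current decoder state, shape (dec_hidden,); zeros before the first output.
    Returns the attention weights a (shape (T,)) and the context vector c.
    """
    # Fully connected layer on each (h_j, s) pair, then one scalar score per word.
    scores = np.tanh(H @ W_h.T + W_s @ s) @ v  # shape (T,)
    a = softmax(scores)                         # a_1..a_T: non-negative, sum to 1
    c = a @ H                                   # weighted sum of the h vectors
    return a, c

rng = np.random.default_rng(3)
T, hidden, dec_hidden, att = 3, 4, 4, 5
H = rng.normal(size=(T, hidden))  # h_1, h_2, h_3 from the encoder
W_h = rng.normal(size=(att, hidden))
W_s = rng.normal(size=(att, dec_hidden))
v = rng.normal(size=att)

s = np.zeros(dec_hidden)  # first output: no previous decoder state yet
a, c = additive_attention(H, s, W_h, W_s, v)
print(np.round(a, 3), a.sum())  # three weights that sum to 1 (up to rounding)
```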
Imagine that the sequence in the above image was generated as part of a language generation task. In the picture you can see how much attention is paid, and to exactly which words, when generating each new word. If the attention mechanism were to think like a human, this is what it might do:
- Who is jogging? “Tom”
- Who went? “Tom”
- What action is being done in the park? “Jog”
- What is done to the friend? “Met”
- Who met? “Tom”
- Who was met? “Friend”
Attention has more to it than what we just discussed. This is a very high-level view, and we have hardly scratched the surface. Attention was taken several notches higher with the introduction of the “Attention Is All You Need” paper by Vaswani et al. I will come up with my next blog post on “Attention Is All You Need” and Transformers very soon. In the meantime, please feel free to post your questions, or point out any mistakes in the above post. Thanks for your time and patience.