Long Short-Term Memory (LSTM) neural networks are one of the standard tools for processing sequential input today. In this post, we don’t attempt to provide a self-contained tutorial on how they work; since there are already excellent writings and other resources about LSTM online, we will point to those references where appropriate. This blog post is best treated as a brief introduction and starting point, from which you can follow the anchor links to the information that is useful to you.
Recurrent Neural Network (RNN)
Unlike vanilla Feed-Forward (FF) neural networks, RNNs retain some information from the past to predict future events. We also call FFs stateless (there is no state, no “memory”, no “history”) and RNNs stateful (keeping track of the state, having memory of the past, taking the history into account). The memory is maintained by a self-loop: the output of a neuron does not only feed forward, but also comes back as an auxiliary input to the same neuron when a new, subsequent input comes in.
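As a rough sketch of that self-loop (plain NumPy, with made-up sizes), the hidden state produced at one step is fed back into the same cell together with the next input:

```python
import numpy as np

# Hypothetical sizes, purely for illustration.
input_size, hidden_size = 4, 8
rng = np.random.default_rng(0)

# Parameters of a single recurrent cell.
W_xh = rng.standard_normal((hidden_size, input_size)) * 0.1   # input -> hidden
W_hh = rng.standard_normal((hidden_size, hidden_size)) * 0.1  # hidden -> hidden (the self-loop)
b_h = np.zeros(hidden_size)

def rnn_step(x, h_prev):
    """One time-step: the previous hidden state comes back as an extra input."""
    return np.tanh(W_xh @ x + W_hh @ h_prev + b_h)

# Feed a whole sequence, carrying the state from step to step.
sequence = rng.standard_normal((5, input_size))  # 5 time-steps
h = np.zeros(hidden_size)                        # "empty memory" at the start
for x_t in sequence:
    h = rnn_step(x_t, h)                         # the self-loop in action
```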
RNNs are suitable for tasks that involve sequential or time-series inputs, e.g. predicting the next scene in a movie or the next word in a sentence. In general, RNNs may be beneficial for every problem in which using past information helps. Recognizing objects in a random image, for example, does not seem to be among those problems.
Note that the input sequences need not be of fixed length. This is one difference between RNNs and most other architectures, e.g. CNNs (Convolutional Neural Networks) or traditional ML algorithms like SVMs, linear regression, tree-based methods, etc.
RNNs also introduce a fancy piece of jargon: backpropagation-through-time, which is nothing more than normal backpropagation done over a sequence. By using a sequence instead of a set of individual elements, we mean that the order of the elements matters. This order is normally interpreted as a timeline: the first element stands at the first time-step, the second element at the second time-step, and so on. More on this later…
Traditional RNNs are quite simple, with just a weight matrix (or a state matrix, depending on what you call it) and a tanh activation function in each neuron. Intuitively, the weight matrix is there to allow the network to learn, and the tanh acts as a controller that adds nonlinearity and allows the outputs to be negative (this is why the logistic function, whose outputs are always positive, is not used in place of the tanh). However, this implementation has a big flaw: it is not protected from exploding and vanishing (backward) gradients when more layers are stacked (i.e. deep learning). While the exploding problem can be handled somewhat easily with hard clipping and/or regularization, the vanishing problem is more difficult to deal with. LSTM presents a solution to this problem by reducing the multiplications in the backpropagation phase.
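For the exploding side, a common remedy is to clip the gradient norm after the backward pass. Here is a minimal, illustrative sketch using PyTorch’s built-in RNN (which uses tanh by default); the sizes and the clipping threshold are arbitrary choices, not recommendations:

```python
import torch
import torch.nn as nn

# A plain (tanh) RNN followed by a linear read-out; sizes are illustrative.
rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
head = nn.Linear(8, 1)
params = list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

x = torch.randn(2, 5, 4)   # (batch, time-steps, features)
target = torch.randn(2, 1)

output, h_n = rnn(x)                 # h_n: final hidden state
loss = nn.functional.mse_loss(head(h_n[-1]), target)

optimizer.zero_grad()
loss.backward()                      # backpropagation-through-time under the hood
# Hard clipping of the gradient norm to tame exploding gradients.
torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
optimizer.step()
```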
LSTM
LSTM is rather more complex than the old RNN architecture. It maintains a “highway” cell state with the help of three gates: the forget gate, the input gate, and the output gate.
Note that the cell state, which stores the memory (what has happened in the current sequence up to the current point in time), is reset before a new sequence is fed to the network. The gates’ weights, on the other hand, are never reset; they are only updated as we feed in new training sequences.
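A minimal sketch of one LSTM step (plain NumPy, illustrative sizes) may help make the roles of the gates and the cell state concrete; note how the state (h, c) starts from zeros for every new sequence while the weights persist:

```python
import numpy as np

input_size, hidden_size = 4, 8
rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# One weight matrix and bias per gate, plus one for the candidate cell value.
# Each acts on the concatenation [h_prev, x].
W = {name: rng.standard_normal((hidden_size, hidden_size + input_size)) * 0.1
     for name in ("forget", "input", "output", "candidate")}
b = {name: np.zeros(hidden_size) for name in W}

def lstm_step(x, h_prev, c_prev):
    z = np.concatenate([h_prev, x])
    f = sigmoid(W["forget"] @ z + b["forget"])    # what to erase from the cell state
    i = sigmoid(W["input"] @ z + b["input"])      # what to write into the cell state
    o = sigmoid(W["output"] @ z + b["output"])    # what to expose as the hidden state
    c_tilde = np.tanh(W["candidate"] @ z + b["candidate"])
    c = f * c_prev + i * c_tilde                  # the "highway" cell state
    h = o * np.tanh(c)
    return h, c

# The state is reset for every new sequence; the weights W, b are not.
for sequence in (rng.standard_normal((5, input_size)),
                 rng.standard_normal((7, input_size))):
    h, c = np.zeros(hidden_size), np.zeros(hidden_size)
    for x_t in sequence:
        h, c = lstm_step(x_t, h, c)
```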
One of the most (if not the most) influential explanations of LSTM on the internet is given by Christopher Olah on his blog. The essay is outstanding in various respects.
Just keep in mind that each value should be thought of as a vector (or matrix) of fairly large size, not a simple scalar. However, you don’t need to care much about the specific shapes (sizes) of the values yet, at least during your first read, as doing so would only over-complicate things.
One amazing point this blog makes is to present LSTMs (and RNNs in general) in the unrolled representation. Looking at RNNs from this point of view makes things much clearer in many respects. The backpropagation-through-time technique we mentioned above is exactly vanilla backpropagation if we unroll the RNN this way. Note, however, that in practice the RNNs are indeed unrolled during training and inference, but only in temporary memory. During the forward pass, the unrolled RNN layers are created and stored in memory; when the back-propagated update reaches the first RNN layer, all others but this first layer are removed.
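A rough PyTorch sketch of what “unrolled in temporary memory” amounts to (sizes are made up): applying an LSTMCell step by step records one copy of the computation per time-step in the autograd graph, a single backward() call is plain backpropagation over that unrolled graph, and the per-step buffers are then freed.

```python
import torch
import torch.nn as nn

# Illustrative sizes; an LSTMCell applied step by step is the "unrolled" view.
cell = nn.LSTMCell(input_size=4, hidden_size=8)
head = nn.Linear(8, 1)

x = torch.randn(5, 2, 4)        # (time-steps, batch, features)
h = torch.zeros(2, 8)
c = torch.zeros(2, 8)

# Forward pass: one copy of the cell's computation per time-step is recorded
# in the autograd graph (the "temporary memory" holding the unrolled network).
for x_t in x:
    h, c = cell(x_t, (h, c))

loss = head(h).pow(2).mean()
loss.backward()                 # plain backprop over the unrolled graph = BPTT
# After backward() the per-step intermediate buffers are freed; only the
# cell's shared weights and their accumulated gradients remain.
```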
Secondly, there is a great recording of a lecture on RNNs from Stanford University, which they have generously published on YouTube (the slides are here). The video lasts 1h13m and we would recommend watching it in full. The lecturer, Justin Johnson, doesn’t just describe RNNs and LSTM but also gives a bunch of inspiring examples of what this type of network can achieve, including: using RNNs on non-sequential inputs, imitating Shakespeare’s writings, auto-generating Linux source code, and more. If you prefer reading to watching, you may want to check out Karpathy’s blog post The Unreasonable Effectiveness of Recurrent Neural Networks, which presents the same insights.
As a side reference, a team of three researchers from Google did an interesting experiment to verify the strength of LSTM and GRU (Gated Recurrent Unit, a cousin of LSTM) against ten thousand other, similar architectures. Here are the key findings:
References: