LSTM: where to start?


The Long Short-Term Memory (LSTM) neural network is one of the standard architectures for processing sequential input today. In this post, we don’t attempt to provide a self-contained tutorial on how it works; rather, since there are already excellent writings and other resources about LSTM online, we will point to those references when appropriate. This blog post is best treated as a brief introduction and starting point, from which you can follow the links to the material that is most useful to you.

Recurrent Neural Network (RNN)

Unlike vanilla feed-forward (FF) neural networks, RNNs retain some information from the past to predict future events. We also call FF networks stateless (there is no state, no “memory”, no “history”) and RNNs stateful (keeping track of the state, having memory of the past, taking the history into account). The memory is maintained by a self-loop: the output of a neuron does not only feed forward, but also comes back as an auxiliary input to the same neuron when the next input of the sequence arrives.

Figure: a feed-forward neuron versus a recurrent neuron
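To make the self-loop concrete, here is a minimal Python/NumPy sketch (sizes and weights are made up for illustration): the feed-forward neuron looks only at the current input, while the recurrent one also re-uses its own previous output.

```python
import numpy as np

rng = np.random.default_rng(0)
x_seq = rng.normal(size=(5, 3))     # 5 time steps, 3 input features
W_x = rng.normal(size=(3,)) * 0.1   # input weights (shared across time steps)
w_h = 0.5                           # self-loop weight
b = 0.0

# Feed-forward: each output depends only on the current input.
ff_out = [np.tanh(x @ W_x + b) for x in x_seq]

# Recurrent: the previous output feeds back in, acting as "memory".
h = 0.0
rnn_out = []
for x in x_seq:
    h = np.tanh(x @ W_x + w_h * h + b)
    rnn_out.append(h)
```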

RNNs are suitable for tasks that involve sequential or time-series inputs, e.g. predicting the next scene in a movie or the next word in a sentence. In general, RNNs may be beneficial for any problem in which using past information helps. Recognizing objects in a random image, for example, does not seem to be one of those problems.

Note that the input sequences need not be of fixed length. This is one way RNNs differ from most other architectures, e.g. CNNs (Convolutional Neural Networks) or traditional ML algorithms such as SVMs, linear regression, tree-based methods, etc.

RNNs also introduce a fancy piece of jargon: backpropagation-through-time, which is nothing more than normal backpropagation done over a sequence. By a sequence, as opposed to a set of individual elements, we mean that the order of the elements matters. This order is normally interpreted as a timeline: the first element stands at the first time step t_1, the second element at the second time step t_2, and so on. More on this later…

Traditional RNNs are quite simple: each neuron has just a weight matrix (or state matrix, depending on what you call it) and a tanh activation function. Intuitively, the weight matrix is there to let the network learn, and the tanh acts as a controller that adds nonlinearity and allows the state values to be negative (this is why the logistic function is not used in place of the tanh). However, this implementation has a big flaw: it is not protected from exploding and vanishing gradients when the gradient is propagated backward through many time steps (or many stacked layers, i.e. deep learning). While the exploding problem can be handled reasonably easily with hard clipping and/or regularization, the vanishing problem is harder to deal with. LSTM addresses it by reducing the number of multiplications the gradient goes through during backpropagation.
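To see why this happens, here is a toy Python sketch (with made-up numbers) of the mechanism: backpropagation-through-time multiplies the gradient by (roughly) the recurrent weight once per time step, so over T steps the scaling factor behaves like w**T.

```python
# Toy illustration of vanishing/exploding gradients in a plain RNN.
# (In the real thing the factor also includes tanh', which only shrinks it further.)
T = 50
for w in (0.9, 1.0, 1.1):
    grad_factor = 1.0
    for _ in range(T):
        grad_factor *= w            # one multiplication per unrolled time step
    print(f"w = {w}: gradient scaled by about {grad_factor:.3g}")
# w = 0.9 -> ~0.005 (vanishing); w = 1.1 -> ~117 (exploding)
```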

LSTM

LSTM is rather more complex than the old RNN architecture. It maintains a “highway” cell state with the help of three gates, the forget gate, the input gate, and the output gate (a minimal code sketch of one LSTM step follows the list below).

  • By “highway cell state”, we mean that the cell state is less interrupted during backpropagation, making the gradient vanishing problem much less likely.
  • The forget gate decides which and how much information from the previous steps should be forgotten given the new input.
  • The input gate determines how new information is added to the cell state.
  • Lastly, the output gate decides what the output should be, given the input and the current cell state.
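For concreteness, here is a minimal NumPy sketch of a single LSTM step following the standard gate equations (the names, sizes, and the way the weights are packed are made up for illustration, not any particular library’s API):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM time step. W and b stack the parameters of all four gates."""
    z = W @ np.concatenate([h_prev, x]) + b   # compute all gates at once
    f, i, o, g = np.split(z, 4)
    f = sigmoid(f)                 # forget gate: what to drop from c_prev
    i = sigmoid(i)                 # input gate: how much new info to add
    o = sigmoid(o)                 # output gate: what to expose as output h
    g = np.tanh(g)                 # candidate values for the cell state
    c = f * c_prev + i * g         # the "highway" cell state update
    h = o * np.tanh(c)
    return h, c

# Toy usage over a short sequence.
rng = np.random.default_rng(0)
n_in, n_hidden = 3, 4
W = rng.normal(size=(4 * n_hidden, n_hidden + n_in)) * 0.1
b = np.zeros(4 * n_hidden)
h, c = np.zeros(n_hidden), np.zeros(n_hidden)
for x in rng.normal(size=(5, n_in)):          # 5 time steps
    h, c = lstm_step(x, h, c, W, b)
```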

Note that the cell state, which stores the memory (what has happened in the current sequence up to the current point in time), is reset before a new sequence is fed to the network. The gates’ weights, on the other hand, are never reset; they are only updated as we train on new input sequences.
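As a hedged illustration of this reset behaviour, here is a small tf.keras sketch (the shapes are made up; see the StackOverflow answer in the references for the details of when Keras resets the state). By default (stateful=False) Keras starts every batch of sequences from a zero state; with stateful=True you control the reset yourself, while the learned gate weights persist regardless:

```python
import numpy as np
import tensorflow as tf

n_steps, n_features, n_units = 10, 3, 8
inputs = tf.keras.Input(shape=(n_steps, n_features), batch_size=1)
lstm = tf.keras.layers.LSTM(n_units, stateful=True)
model = tf.keras.Model(inputs, lstm(inputs))

x = np.random.rand(1, n_steps, n_features).astype("float32")
model.predict(x)       # the cell state now summarises this sequence
lstm.reset_states()    # wipe the memory before feeding an unrelated sequence
```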

One of the most (if not the most) influential explanations of LSTM on the internet is given by Christopher Olah in his blog. This essay is outstanding in several respects:

  • The wording is exact and carefully chosen.
  • There is a brief introduction of traditional RNNs as well.
  • The author takes time to walk slowly through every bit of LSTM.
  • The maths is concise and intuitively explained.
  • The visualizations are excellent.

Just keep in mind that each value should be thought of as a vector (or matrix) of considerable size, not a simple scalar. However, you don’t need to care much about the specific shapes (sizes) of the values yet, at least on your first read, as doing so would only over-complicate things.

One great point this blog makes is to present LSTMs (and RNNs in general) in the unrolled representation. Looking at RNNs from this point of view makes many things much clearer. The backpropagation-through-time technique we mentioned above is exactly vanilla backpropagation applied to the RNN unrolled this way. Note, however, that in practice RNNs are indeed unrolled during training and inference, but only in temporary memory: during the forward pass, the activations of each unrolled step are stored, and once the backward pass has propagated its updates back to the first step, everything but that first copy of the layer can be discarded.

Secondly, there is a great recording of a lecture on RNNs from Stanford University, generously published on YouTube (the slides are here). The video lasts 1h13m and we recommend watching it in full. The lecturer, Justin Johnson, doesn’t just describe RNNs and LSTM but also gives a bunch of inspiring examples of what this type of network can achieve, including using RNNs on non-sequential inputs, imitating Shakespeare’s writing, auto-generating Linux source code, and more. If you prefer reading to watching, you may want to check out Karpathy’s blog post The Unreasonable Effectiveness of Recurrent Neural Networks, which presents the same insights.

As a side reference, a team of three researchers from Google ran an interesting experiment to compare the strength of LSTM and GRU (Gated Recurrent Unit, a cousin of LSTM) against more than ten thousand other, similar architectures. Here are the key findings:

  • None of the other architectures consistently beats LSTM in all tasks.
  • The forget gate seems to be the most important component in most tasks.
  • Setting a large bias on the forget gate at initialization greatly improves LSTM’s performance (a code sketch follows below).
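The last finding is easy to adopt in practice. As a hedged example, tf.keras exposes it through the unit_forget_bias flag (enabled by default), which initialises the forget-gate bias to one so that the network starts out remembering by default, following the Jozefowicz et al. paper referenced below:

```python
import tensorflow as tf

# unit_forget_bias=True (the default) adds 1 to the forget-gate bias
# at initialization; the layer size here is made up for illustration.
lstm = tf.keras.layers.LSTM(units=32, unit_forget_bias=True)
```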

References:

  • Colah’s blog on LSTM: blog
  • Stanford University’s lecture on RNNs: video
  • Karpathy’s blog post about the effectiveness of RNNs: blog
  • An Empirical Exploration of Recurrent Network Architectures, Rafal Jozefowicz et al.: paper
  • When does Keras reset an LSTM cell state, StackOverflow: answer
