The Unreasonable Effectiveness of Next-Token Prediction

The lost glory of RNNs

Sebastian Raschka had a great tweet a few days ago (definitely follow his newsletter for the latest AI news).

This thought has been bouncing around in my mind recently as well, so this is a good opportunity to unpack it.

What is the 'unreasonable effectiveness of RNNs'?

RNNs have existed for a long time. Even LSTMs, one of the more complicated flavors of RNNs, were invented way back in 1997. However, well into the 2000s, all of deep learning was still in relative obscurity. In 2012, a CNN (AlexNet, trained on ImageNet) demolished all previous computer vision benchmarks and deep learning finally arrived in the mainstream. RNNs were quickly dusted off and immediately proved to be much better than CNNs for sequential data such as time series and NLP.

In 2015, Andrej Karpathy (then a grad student, later Director of AI at Tesla) published an influential blog post with the title 'The unreasonable effectiveness of RNNs'. Side note: this was one of the early meme titles that spawned dozens of papers and articles using the format 'The unreasonable _ of _'. The blog post, worth a read even today, detailed the state of the art in RNNs.

One of the results Andrej highlighted in 2015 was a SOTA result on image captioning: an RNN-based model that takes images as input and returns captions describing them. At the time, RNNs had already achieved SOTA on many language tasks such as translation. Two key innovations helped RNNs achieve this:

  1. Attention

  2. Next token prediction

What is next token prediction?

Instead of traditional feature-label pairs, language models can be trained on just a corpus of sentences, where the sequence of words serves as both the feature and the label. This concept has existed for a long time in language modeling and time-series modeling, and it is applicable to any ordered sequence: the model takes part of the sequence as input and predicts the next token of the sequence. In fact, even non-sequential data such as images can be broken into patches that can be read as an ordered sequence.
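As an illustration, here is a minimal Python sketch (a toy corpus and a naive word-level tokenizer, purely hypothetical) of how a plain corpus becomes feature-label pairs for next-token prediction, with no separate annotations required:

```python
# Minimal sketch: turning a raw corpus into next-token-prediction
# training pairs. The word-level "tokenizer" here is a toy stand-in
# for the subword tokenizers real models use.

corpus = "the cat sat on the mat"
tokens = corpus.split()                      # ["the", "cat", "sat", ...]

vocab = {word: i for i, word in enumerate(sorted(set(tokens)))}
ids = [vocab[word] for word in tokens]       # words mapped to integer ids

# Each prefix of the sequence is the "feature"; the token that follows
# it is the "label". Every position in the corpus yields an example.
pairs = [(ids[:i], ids[i]) for i in range(1, len(ids))]

for context, target in pairs:
    print(f"context={context} -> next token={target}")
```

In practice, real models use subword tokenizers and fixed-length context windows, but the core idea is the same: every position in the corpus is a free training example, so no labeling effort is needed.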

However, by 2015, cumulative improvements in datasets, accelerators, and architectural methods had reached a turning point where next-token-prediction language models started vastly outperforming other models. Since labeled data was not required, the size of the model and dataset was gated only by hardware capabilities and the time needed to train. Thus billion-parameter models and massive language corpora such as all of Wikipedia (or all the text on the internet) could now be used to train language models.

So are RNNs not unreasonably effective?

In short, no, they are not.

2015 was definitely peak RNN. Since then, RNNs have fallen off hard and they are unlikely to make a comeback. Transformers can do everything RNNs could and a lot more. As it turned out, both attention and next-token prediction proved to be extremely powerful components of future SOTA transformers such as BERT and GPT-3. Thus, in hindsight, the success of RNNs in 2015 was not due to RNNs at all. In fact, since RNN computations are sequential and cannot be parallelized as efficiently as attention, transformers, or even CNNs, RNNs were probably the gating factor holding back models and datasets from scaling up. Given how important scaling has proven to be, the RNN might well be the most ineffective architecture!
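To make the parallelism point concrete, here is a toy NumPy sketch (arbitrary shapes, untrained random weights, not any particular model) contrasting the two: the RNN's hidden state at step t cannot be computed until step t-1 is done, while attention covers all positions with a few matrix multiplications that map cleanly onto parallel hardware.

```python
import numpy as np

# Illustrative only: why RNNs are hard to parallelize over the
# sequence dimension, while attention is not. Toy sizes throughout.

seq_len, d = 6, 4
x = np.random.randn(seq_len, d)           # one sequence of 6 token vectors

# RNN: each hidden state depends on the previous one, so the loop
# over time steps cannot be parallelized across positions.
W_h = np.random.randn(d, d)
W_x = np.random.randn(d, d)
h = np.zeros(d)
for t in range(seq_len):
    h = np.tanh(h @ W_h + x[t] @ W_x)     # step t waits for step t-1

# Single-head attention: every position attends to every other
# position via matrix multiplications, all computable at once.
W_q = np.random.randn(d, d)
W_k = np.random.randn(d, d)
W_v = np.random.randn(d, d)
Q, K, V = x @ W_q, x @ W_k, x @ W_v
scores = Q @ K.T / np.sqrt(d)             # all pairs of positions at once
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
out = weights @ V                          # no per-step dependency
```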