What's so special about Transformers?

Transformers were mentioned so many times in my post about CNNs that I figured I should address the hype directly. Are they just the latest new architecture, like RNNs and CNNs before them? Or is there something special about Transformers?

Are Transformers important?

Yes.

"Attention is all you need" is the seminal paper that unveiled Transformers to the world in 2017. Its impact was felt immediately. For one, every second publication was titled "___ is all you need". Vinay Prabhu did us all a favor and compiled them all in one place.

But more importantly, ML researchers quickly became convinced that Transformers were not just an incremental improvement. Within a few years, Transformer-based models were setting state-of-the-art results across the major ML categories: vision, NLP, audio and video. Different companies built what Stanford calls "foundation models": large models trained on huge open-ended or unlabeled datasets, which can be used as a starting point for fine-tuned applications.

Many of these large models, such as BERT, GPT-3, DALL-E 2 and Stable Diffusion, are already deployed commercially and used by businesses to process and generate text and images for clients. In 2021 and 2022, droves of researchers left the big FAANG companies to form new startups such as Cohere and Adept. Unlike earlier AI research startups such as DeepMind, these new startups believe that the time has come for AI to jump from research to practical consumer and industrial applications.

Every single so-called "foundation model" is fully or partially a Transformer-based model. That is a lot of confidence to put in one model architecture.

How do Transformers work?

Many volumes have been written about Transformers. If you don't know much about them, the original paper, "Attention is all you need", and Jay Alammar's greatest hit, "The Illustrated Transformer", are the best places to start. An excellent primer on Vision Transformers (ViT) can be found on theaisummer.com. Beyond that, there is a whole forest of Transformer variants (BERT, GPT, Conformer, Performer), and keeping track of them all is nigh impossible. I suspect I will have a lot more to say about specific aspects of Transformers in the future. In this post, I will just focus on the key intuitions that make Transformers special.

Intuition 1 : Transformers Transform

Consider the three main families of architectures that came before: fully connected (FC), convolutional (CNN) and recurrent (RNN) neural networks. FCs can be expressed as special cases of both CNNs and RNNs, so I will set them aside for now. CNNs and RNNs have completely different inductive biases, and it is almost impossible for one to do a task suited to the other. CNNs can extract generalizable filters from images. RNNs process pieces of the input one by one, so they are great for sequences. Much progress was made by mixing these families to get hybrid architectures; for example, CRNNs could learn videos as sequences of images. However, their performance varies a lot depending on the precise location, ordering and number of each type of layer, making it difficult to build larger and more generalizable models.

Transformers originally came out of sequence-to-sequence work, replacing attention-based RNNs in NLP tasks. The similarity between Transformers and attention-RNNs is easy to see: instead of using one recurrent network to process each word of the input sentence one by one, a sequence of identical dense layers processes every word of the input sentence in parallel. Processing the words in parallel loses the sequence ordering that an RNN gets for free, so this is compensated for by adding a positional embedding, which marks the location of each word in the input sequence. Therefore, in principle, Transformers can learn anything RNNs can.
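To make the "in parallel, plus position" idea concrete, here is a minimal NumPy sketch. It assumes the sinusoidal position encoding from "Attention is all you need" and a single dense layer with ReLU; the function names, sizes and random weights are made up for illustration, and the attention that mixes information between words is left out.

```python
import numpy as np

def sinusoidal_positions(seq_len, d_model):
    """Sinusoidal position encoding; d_model is assumed to be even."""
    pos = np.arange(seq_len)[:, None]            # (seq_len, 1)
    two_i = np.arange(0, d_model, 2)[None, :]    # even dimension indices 0, 2, 4, ...
    angles = pos / np.power(10000.0, two_i / d_model)
    enc = np.zeros((seq_len, d_model))
    enc[:, 0::2] = np.sin(angles)                # even dimensions get sine
    enc[:, 1::2] = np.cos(angles)                # odd dimensions get cosine
    return enc

def position_wise_layer(x, w, b):
    """Apply the same dense layer + ReLU to every position with one matmul."""
    return np.maximum(x @ w + b, 0.0)

rng = np.random.default_rng(0)
seq_len, d_model = 5, 8                          # illustrative sizes
word_vectors = rng.normal(size=(seq_len, d_model))
w = rng.normal(size=(d_model, d_model))
b = np.zeros(d_model)

# Mark each word with its position, then process all words at once
x = word_vectors + sinusoidal_positions(seq_len, d_model)
parallel_out = position_wise_layer(x, w, b)

# The word-by-word loop gives the same result, just without the parallelism
loop_out = np.stack([position_wise_layer(x[t], w, b) for t in range(seq_len)])
```

Because the layer treats every word identically, the position encoding is the only thing that tells the model where a word sits in the sentence.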

It took a few years for Transformers to be applied to vision successfully. The implementation has more details, but the simple intuition of ViTs in "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" was to split an image up into small patches and then treat each image patch like a word in an "image sentence". Position embeddings were used to encode the position of each patch in the image. Experimentally, it was shown that this captures CNN-like vision-specific filters.
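The patch-splitting step can be sketched in a few lines of NumPy. This is only an illustration of the idea, not the ViT implementation: the function name and the 32x32 toy image are my own choices, and the learned linear projection and position embeddings that come afterwards are omitted.

```python
import numpy as np

def image_to_patches(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping square patches."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image must divide evenly"
    rows, cols = h // patch_size, w // patch_size
    patches = image.reshape(rows, patch_size, cols, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)   # (rows, cols, patch, patch, C)
    return patches.reshape(rows * cols, patch_size * patch_size * c)

# Toy 32x32 RGB image: four 16x16 patches, i.e. a four-"word" sentence
image = np.arange(32 * 32 * 3, dtype=np.float32).reshape(32, 32, 3)
tokens = image_to_patches(image, patch_size=16)
```

Each row of `tokens` is one patch read left to right, top to bottom, so the rest of the model can treat the image exactly like a sequence of word vectors.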

Therefore, Transformers effectively unified both RNNs and CNNs: they are able to learn images, language, time series and any combination of these with one architecture.

Intuition 2 : Transformers Scale

It is technically possible to replicate a CNN with an RNN using a method similar to the Transformer's: break up the image into patches and use multiple stacked RNNs as different filters. However, training this RNN on ImageNet would take a thousand years. RNNs are terrible at scaling with respect to input size, since both the size of the unrolled network and the computation time increase linearly with input size, and the steps cannot run in parallel. By processing the input words in parallel, Transformers are much faster than RNNs.
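The sequential bottleneck is easy to see in code. Below is a minimal sketch of a vanilla tanh RNN in NumPy (the names, sizes and random weights are illustrative assumptions): every step needs the hidden state from the step before it, so the loop cannot be spread across the sequence.

```python
import numpy as np

def rnn_forward(inputs, w_x, w_h):
    """Vanilla tanh RNN: each step needs the previous hidden state."""
    h = np.zeros(w_h.shape[0])
    states = []
    for x_t in inputs:                       # inherently sequential loop
        h = np.tanh(x_t @ w_x + h @ w_h)
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
seq_len, d_in, d_hidden = 6, 4, 3            # illustrative sizes
inputs = rng.normal(size=(seq_len, d_in))
w_x = 0.5 * rng.normal(size=(d_in, d_hidden))
w_h = 0.5 * rng.normal(size=(d_hidden, d_hidden))

states = rnn_forward(inputs, w_x, w_h)

# Perturbing only the first input feeds through the hidden state into later steps
perturbed = inputs.copy()
perturbed[0] += 1.0
states_perturbed = rnn_forward(perturbed, w_x, w_h)
```

Doubling the sequence length doubles the number of loop iterations, and adding more hardware does not help because step t cannot start before step t-1 has finished.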

CNNs are quite efficient, both in size and computation, with respect to the size of the image. However, CNNs achieve that by locking in a strong bias in the field of view (FOV) of the convolution. To get multiple kernel sizes, we would need multiple layers with different kernel sizes, as in the Inception architecture.

Attention layers were slow to be adopted at first because a single attention layer scales quadratically with the size of the input, while CNNs scale linearly. Remember that attention captures the correlations of every word with every other word, hence quadratic. However, even though one attention layer is expensive, scaling the number of attention heads is linear. By using multiple attention heads, we can effectively perform multiple CNNs with different kernel sizes at the same time. In fact, it is even better, because each "kernel size" is learned by the model from the data and not locked in before training. Therefore, while the number of operations is higher, the actual time taken by attention models scales much better for larger models, thanks to more efficient parallelization on GPUs and TPUs.
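Here is a minimal NumPy sketch of multi-head self-attention that shows where the quadratic cost comes from. It assumes scaled dot-product attention with a fixed per-head size `d_head`, so each extra head adds one more n-by-n score matrix; the names and sizes are illustrative, and masking and the final output projection are left out.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)      # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v):
    """x: (n, d_model); w_q, w_k, w_v: (heads, d_model, d_head)."""
    q = np.einsum("nd,hde->hne", x, w_q)          # (heads, n, d_head)
    k = np.einsum("nd,hde->hne", x, w_k)
    v = np.einsum("nd,hde->hne", x, w_v)
    d_head = q.shape[-1]
    # Every word is scored against every other word: (heads, n, n)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)
    weights = softmax(scores, axis=-1)
    out = weights @ v                             # (heads, n, d_head)
    # Concatenate the heads along the feature axis: (n, heads * d_head)
    return out.transpose(1, 0, 2).reshape(x.shape[0], -1), weights

rng = np.random.default_rng(0)
n, d_model, heads, d_head = 6, 8, 4, 2            # illustrative sizes
x = rng.normal(size=(n, d_model))
w_q, w_k, w_v = (rng.normal(size=(heads, d_model, d_head)) for _ in range(3))
out, weights = multi_head_self_attention(x, w_q, w_k, w_v)
```

The `weights` array has shape `(heads, n, n)`: quadratic in the number of words, linear in the number of heads. All heads are computed in one batched matrix multiply, which is the kind of work GPUs and TPUs handle well.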

Thus, Transformers solve the difficult-to-scale aspects of RNNs (sequential processing) and CNNs (multiple fields of view) using highly parallelizable multi-headed attention layers.

Combining the two advantages of flexibility and scalability, Transformers allow training massive networks with a huge number of parameters, thus breaking both size and accuracy records across the major domains of machine learning.

————————————————————————————-

There you have it, my intuitions for why Transformers are dominating AI research and why many believe they are the future of deep learning.