27 Days of JAX!


I’ve been hearing a lot of talk in ML circles about how differentiable programming is supposed to be the next big thing that will supersede our current workflows. Still, sceptic that I am, I stayed away from it: from what I understood, it is mainly a performance boost for matrix computations that can speed up training, and since that isn’t critical to my projects or research, it didn’t seem worth the hassle. Besides, these days it feels like every new paradigm can supposedly do magic, at least according to those who evangelize it. However, after letting my curious side look into some of the code examples for Taichi (a differentiable programming language compatible with Python), it seemed like this technology offers benefits beyond a raw performance increase. Then yesterday, I stumbled upon this tweet:

So, I decided to figure out what differentiable programming is all about, whether its benefits matter for my use cases, and whether there are drawbacks (a priori, I assume there are many). Thus, I’m starting a blog mini-series in which I document my journey with JAX.

ML workloads often consist of large, accelerable, pure-and-statically-composed subroutines orchestrated by dynamic logic. - Matthew Johnson (Google Brain)

Day 1: Whetting our appetite for performance!

First, we want to activate our conda environment and then install jax:

conda install -c conda-forge jax
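
Once the install finishes, a quick sanity check along these lines confirms the import works and shows which backend JAX has picked up:

import jax
import jax.numpy as jnp

print(jax.devices())   # the backend(s) JAX sees: CPU, GPU, or TPU
print(jnp.arange(5))   # a small array placed on the default device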

Now, let’s see how its performance compares to NumPy on a simple task:

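A minimal sketch of such a comparison (the 3000×3000 matrix product is an arbitrary choice) times the same dot product in plain NumPy and in jax.numpy:

import time
import numpy as np
import jax.numpy as jnp

size = 3000
x_np = np.random.normal(size=(size, size)).astype(np.float32)
x_jnp = jnp.array(x_np)  # copy the same data onto JAX's default device

# NumPy: plain matrix product on the CPU
start = time.perf_counter()
np.dot(x_np, x_np.T)
print(f"NumPy: {time.perf_counter() - start:.4f} s")

# JAX: block_until_ready() waits for the asynchronously dispatched computation to finish
start = time.perf_counter()
jnp.dot(x_jnp, x_jnp.T).block_until_ready()
print(f"JAX:   {time.perf_counter() - start:.4f} s")

The call to block_until_ready() matters: JAX dispatches work asynchronously, so without it we would mostly be timing how fast Python can enqueue the operation.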

Okay! That part seems clear enough: JAX is straight-up blowing regular old NumPy out of the water.

Tomorrow: how do we vectorize scalar operations?

Author: Carlo Pecora

Created: 2021-12-01 Wed 23:18