Intro to JAX for Machine Learning

JAX is a Python library designed for high-performance numerical computing, especially machine learning research. It accelerates Python and NumPy code with the use of GPUs.

By Kevin Vu · Dec. 14, 22 · Tutorial


JAX is an up-and-coming library in the machine learning space that aims to make machine learning simple yet efficient. JAX is still a Google and DeepMind research project rather than an official Google product, but it has been used extensively internally and adopted by external ML researchers. This article offers an introduction to JAX, how to install it, and its advantages and capabilities.

What Is JAX for Machine Learning?

JAX is a Python library designed for high-performance numerical computing, especially machine learning research. Its API for numerical functions is based on NumPy, the collection of functions widely used in scientific computing. JAX accelerates the machine learning process by using XLA to compile NumPy-style functions on GPUs and by using autograd to differentiate Python and NumPy functions for gradient-based optimization. JAX can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives with ease, all with GPU acceleration. JAX supports both backpropagation (reverse-mode differentiation) and forward-mode differentiation.

JAX offers superior performance when using GPUs to run your code and a just-in-time (JIT) compilation option to easily speed up large projects, which we will delve into later in this article. 

Think of JAX as a Python library that augments NumPy and Python code with function transformations to enable accelerated machine learning. As a general rule, use JAX whenever you plan to train on GPUs, compute gradients (autograd), or JIT-compile your code.
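As a quick illustration of the autograd side, here is a minimal sketch (the toy function f is not from the original article): jax.grad turns a Python function into its derivative, and it can be applied repeatedly for higher-order derivatives.

Python
 
import jax

# Toy scalar function: f(x) = x^3 + 2x
def f(x):
    return x ** 3 + 2.0 * x

df = jax.grad(f)     # first derivative: 3x^2 + 2
d2f = jax.grad(df)   # second derivative: 6x

print(df(2.0))   # 14.0
print(d2f(2.0))  # 12.0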

Why Use JAX?

In addition to running on ordinary CPUs, JAX's main strength is that it works seamlessly with other processing units such as GPUs (and TPUs). This gives JAX a great advantage over similar packages, because GPU parallelization enables much faster performance than CPUs for image and vector processing.

This is especially important because NumPy users routinely build exceptionally large matrices, and GPUs are far more time-efficient at processing data of that size.

This time difference enables JAX to exceed plain NumPy by over 100 times in speed and performance, thanks to a few key features:

  • Vectorization - processing batches of data with a single instruction, which provides large speedups for linear algebra computations and machine learning (see the vmap sketch after this list)
  • Code Parallelization - taking serial code that would run on a single processor and distributing it across many; GPUs are preferred here because they have numerous processors specialized for these computations
  • Automatic Differentiation - simple, straightforward differentiation that can be chained multiple times to evaluate higher-order derivatives with ease
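As a small sketch of the vectorization point (the function and data below are illustrative, not from the original article), jax.vmap maps a per-example function over a batch without an explicit Python loop:

Python
 
import jax
import jax.numpy as jnp

# A function written for a single vector of shape (3,)
def weighted_sum(v):
    weights = jnp.array([0.2, 0.3, 0.5])
    return jnp.dot(weights, v)

# vmap applies it across a batch of shape (4, 3) as one vectorized operation
batched = jax.vmap(weighted_sum)
batch = jnp.arange(12.0).reshape(4, 3)
print(batched(batch))  # one weighted sum per row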

How to Install JAX

To install the CPU-only version of JAX, which might be useful for doing local development on a laptop, you can run

Shell
 
pip install --upgrade pip
pip install --upgrade "jax[cpu]"


On Linux, it is often necessary to first update pip to a version that supports manylinux2014 wheels.

pip installation: GPU (CUDA)

To install JAX with both CPU and NVIDIA GPU support, you must first install CUDA and CuDNN, if they haven’t already been installed. Unlike many other popular deep learning systems, JAX does not bundle CUDA or CuDNN as part of the pip package.

JAX provides pre-built CUDA-compatible wheels for Linux only, with CUDA 11.1 or newer and cuDNN 8.0.5 or newer. Other combinations of operating system, CUDA, and cuDNN are possible but require building from source.

  • CUDA 11.1 or newer is required
    • You may be able to use older CUDA versions if you build from source, but there are known bugs in all CUDA versions older than 11.1, so prebuilt binaries are not shipped for older CUDA versions.
  • The supported cuDNN versions for the prebuilt wheels are:
    • cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough since it supports additional functionality.
    • cuDNN 8.0.5 or newer.
  • You must use an NVIDIA driver version that is at least as new as your CUDA toolkit's corresponding driver version. For example, if you have CUDA 11.4 update 4 installed, you must use NVIDIA driver 470.82.01 or newer if on Linux. This is a strict requirement that exists because JAX relies on JIT-compiling code; older drivers may lead to failures.
    • If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the CUDA forward compatibility packages that NVIDIA provides for this purpose.
Shell
 
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda]" https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


The jaxlib version must correspond to the version of the existing CUDA installation you want to use. You can specify a particular CUDA and CuDNN version for jaxlib explicitly:

Shell
 
pip install --upgrade pip
# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 
# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


You can find your CUDA version with the command:

Shell
 
nvcc --version


Some GPU functionality expects the CUDA installation to be at /usr/local/cuda-X.X, where X.X should be replaced with the CUDA version number (e.g. cuda-11.1). If CUDA is installed elsewhere on your system, you can create a symlink:

Shell
 
sudo ln -s /path/to/cuda /usr/local/cuda-X.X


Comparing JAX to NumPy

Since JAX is essentially an augmented NumPy, the two libraries have very similar syntax, so you can use them almost interchangeably. In smaller projects, the acceleration JAX provides is often negligible in terms of time saved; the larger your models get, the more you should consider JAX.
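To see how closely the two APIs mirror each other, here is a minimal side-by-side sketch (illustrative only, not from the original article):

Python
 
import numpy as np
import jax.numpy as jnp

a_np = np.linspace(0.0, 1.0, 5)
a_jax = jnp.linspace(0.0, 1.0, 5)

# The same NumPy-style calls work in both libraries
print(np.sum(a_np ** 2), jnp.sum(a_jax ** 2))
print(np.mean(a_np), jnp.mean(a_jax))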


Multiplying Two Matrices Using JAX vs. NumPy

To clearly illustrate the speed difference between these two libraries, we will use both to multiply two matrices by each other and then check the performance differences between CPU only and GPU. We will also check the performance boost that is caused by the JIT compiler. 

To follow along with this tutorial, install and import the JAX and NumPy libraries (from the previous step). You can test your code on sites such as Kaggle or Google Colab. As with any library, you should import JAX by writing the following lines at the beginning of your code:

Python
 
import jax.numpy as jnp
from jax import random


You can also import the NumPy library in a similar manner:

Python
 
import numpy as np


Next, we will compare the performance of JAX and NumPy on the CPU and GPU by multiplying two matrices together in Python. For these benchmarks, lower is better.

NumPy on CPU

To begin, we will create a 5,000 x 5,000 matrix using NumPy and time how long it takes to multiply it by its transpose.

Python
 
import numpy as np

size = 5000
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)


785 ms per loop

A single loop of the code running on NumPy took around 785 ms to run.

JAX on CPU

Now let’s run the same code, but this time using the JAX library.

Python
 
import jax.numpy as jnp
from jax import random

# JAX handles randomness with explicit PRNG keys rather than a global seed
key = random.PRNGKey(0)
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()


1.43 sec per loop

As you can see, comparing JAX and NumPy CPU-only performance shows that NumPy is the faster option. While JAX may not provide the best performance with normal CPUs, it does provide much better performance with GPUs.

JAX With GPU

Now, let's try to create the same 5,000 by 5,000 matrix, this time using JAX with a GPU instead of the regular CPU:

Python
 
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
size = 5000

x = random.normal(key, (size, size)).astype(jnp.float32)
# Move the array onto the accelerator before timing the matrix multiplication
%time x_jax = jax.device_put(x)
%time jnp.dot(x_jax, x_jax.T).block_until_ready()   # first run includes compilation
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # steady-state timing


80.6 ms per loop

As clearly shown, running JAX on a GPU instead of a CPU brings the time down to around 80 ms per loop, roughly 10 times faster than NumPy on the CPU (and nearly 18 times faster than JAX on the CPU). The gap only widens with larger matrices.

Just-in-Time Compilation (JIT)

Using the jit function, our code is compiled with XLA so that our functions execute efficiently.

XLA, short for Accelerated Linear Algebra, is a domain-specific linear algebra compiler used by libraries such as JAX and TensorFlow to compile and run code on GPUs with greater efficiency.
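As a minimal sketch of how jit is typically applied (the normalize function below is illustrative; the article's own SELU benchmark follows), you can wrap a function with jit or use it as a decorator:

Python
 
import jax
import jax.numpy as jnp

@jax.jit  # compile the whole function with XLA on its first call
def normalize(x):
    return (x - jnp.mean(x)) / jnp.std(x)

x = jnp.arange(10.0)
print(normalize(x))  # the first call traces and compiles; later calls reuse the compiled code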

We will test our code using a SELU (Scaled Exponential Linear Unit) activation function, comparing the time performance of NumPy on a normal CPU against JAX running on a GPU with JIT.

Python
 
import numpy as np
import jax.numpy as jnp

# SELU activation implemented twice: once with NumPy and once with jax.numpy
def selu_np(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

def selu_jax(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


NumPy on CPU 

To start with, we will create a vector of size 1,000,000 using the NumPy library.

Python
 
import numpy as np

x = np.random.normal(size=(1000000,)).astype(np.float32)
%timeit selu_np(x)


8.3 ms per loop

JAX on GPU with JIT

Now we will test our code while using JAX and JIT on a GPU.

Python
 
import jax
import jax.numpy as jnp
from jax import random, jit

key = random.PRNGKey(0)

# jax.numpy version of the SELU activation defined above
def selu_jax(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))

# Compile with XLA, move the data onto the GPU, then time the compiled function
selu_jax_jit = jit(selu_jax)
%time x_jax = jax.device_put(x)
%time selu_jax_jit(x_jax).block_until_ready()   # first call includes compilation time
%timeit selu_jax_jit(x_jax).block_until_ready() # steady-state timing


153 µs per loop (0.153 milliseconds per loop)

Lastly, when using the JIT compiler on a GPU, we get much better performance than NumPy on the CPU. As you can clearly see, the difference is dramatic: roughly a 50x speedup (about a 5,000% increase) from NumPy to JAX with JIT.

Think of JAX as a modification of NumPy that enables accelerated machine learning on GPUs. Since NumPy code can only run on the CPU, JAX is faster than NumPy whenever you opt to execute code on GPUs. As a general rule, use JAX whenever you plan to run NumPy-style code on GPUs or to JIT-compile your code.

Note: the examples used in this tutorial are taken from the original code linked in the source article.

JAX Limitation: Pure Functions

JAX transformations and compilations are designed for Python functions that are functionally pure. A pure function cannot change the program's state by reading or writing variables outside its own scope, and it cannot have side effects such as writing to input/output streams (for example, with print()).

If a function is not pure, consecutive runs can cause its side effects to behave differently than intended. If you are not careful, untracked side effects can throw off the accuracy of your intended computations.
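A minimal sketch of the kind of surprise this causes (illustrative, not from the original article): a print() inside a jit-compiled function only fires while the function is being traced, not on every call.

Python
 
import jax
import jax.numpy as jnp

@jax.jit
def impure_double(x):
    print("tracing...")  # side effect: runs during tracing, not on every call
    return 2.0 * x

print(impure_double(jnp.arange(3.0)))      # prints "tracing..." once, then the result
print(impure_double(jnp.arange(3.0) + 1))  # same shape and dtype: no "tracing..." this time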

Using Google’s JAX 

In this article, we explained the capabilities of JAX and the advantages it brings over plain NumPy, and we covered how to install the library.

We then imported JAX and NumPy, compared the two (NumPy being JAX's best-known counterpart), and measured their time and performance differences on regular CPUs and GPUs, along with some JIT tests, observing drastic speed improvements.

If you are an advanced machine/deep learning practitioner, adding a library such as JAX to your arsenal, with its GPU/TPU acceleration and efficient JIT compiler, will definitely make your life much easier.


Published at DZone with permission of Kevin Vu. See the original article here.

Opinions expressed by DZone contributors are their own.
