Introduction to JAX Data Types and Arrays
Learn about the important data types and arrays in JAX.
Overview
JAX is a Python library for high-performance machine learning that offers:
- Automatic differentiation
- Vectorization
- Just-in-time (JIT) compilation
Data types in JAX
The data types of JAX arrays are similar to those in NumPy. For instance, here is how we can create `float` and `int` data in JAX:
```python
import jax.numpy as jnp

x = jnp.float32(1.25844)
print("x :", x)

y = jnp.int32(45.25844)  # the fractional part is dropped, so y holds 45
print("y :", y)
```
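In the code above, we import the JAX version of NumPy and name it `jnp`. We define two JAX variables, `x` and `y`, of types `float32` and `int32`, respectively. Because `y` has an integer type, the fractional part of 45.25844 is dropped and `y` holds 45. Lastly, we print the values of both variables.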
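To inspect the data type of a JAX value, we can read its `dtype` attribute, and `astype()` converts a value to another type. The sketch below is a minimal illustration. Note that, by default, JAX creates 32-bit values (for example, `float32` rather than NumPy's usual `float64`) unless 64-bit mode is enabled through the `jax_enable_x64` configuration flag.

```python
import jax.numpy as jnp

x = jnp.float32(1.25844)
y = jnp.int32(45.25844)
print(x.dtype, y.dtype)  # float32 int32

# With default settings (64-bit mode off), Python floats become float32.
z = jnp.array([1.0, 2.0])
print(z.dtype)

# astype() returns a new array with the requested type; y itself is unchanged.
y_float = y.astype(jnp.float32)
print(y_float, y_float.dtype)
```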
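When we check the type of these values, we see that both are JAX arrays rather than Python or NumPy scalars. Older JAX releases report this type as `DeviceArray`; since JAX 0.4.1, it is the unified `jax.Array` type, which `type()` shows as an internal `ArrayImpl` class. In the code below, we can see the same type for both the `float32` and `int32` variables.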
```python
import jax.numpy as jnp

x = jnp.float32(1.25844)
print("type of x: ", type(x))

y = jnp.int32(45.25844)
print("type of y: ", type(y))
```
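A more portable way to check the type is `isinstance()` against `jax.Array`, which does not depend on the internal class name that `type()` prints. This sketch assumes JAX 0.4.1 or newer, where `jax.Array` is available. It also shows that scalars created this way are zero-dimensional arrays and that `np.asarray()` converts a JAX array into a regular NumPy array.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.float32(1.25844)
y = jnp.int32(45.25844)

# Both values are instances of the public jax.Array type.
print(isinstance(x, jax.Array), isinstance(y, jax.Array))  # True True

# Scalars created with jnp.float32 / jnp.int32 are zero-dimensional arrays.
print(x.shape, y.shape)  # () ()

# A JAX array is not a numpy.ndarray, but it converts to one easily.
print(isinstance(x, np.ndarray))  # False
x_np = np.asarray(x)
print(type(x_np), x_np.dtype)
```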
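The JAX array type (`DeviceArray` in older releases, `jax.Array` in current ones) is the equivalent of `numpy.ndarray` in NumPy, and `jax.numpy` provides an interface similar to NumPy's. However, JAX also provides `jax.lax`, a lower-level API that is stricter and often more powerful. For example, with `jax.numpy`, we can add numbers that have mixed types because they are automatically promoted to a common type, but `jax.lax` does not allow this and requires an explicit type conversion.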
Ways to create JAX arrays
We can create JAX arrays like we would in NumPy. For example, we can use:
- The `arange()` function
- The `linspace()` function
- Python lists
- The `zeros()` function
- The `ones()` function
- The `identity()` or `eye()` function
Let’s look at the outputs of the functions above:
```python
import jax.numpy as jnp

a = jnp.arange(10)
print("a : ", a)

b = jnp.linspace(0, 10, 30)
print("b :", b)

scores = [50, 60, 70, 30, 25, 70]
scores_array = jnp.array(scores)
print("scores_array :", scores_array)

c = jnp.zeros(5)
print("c :", c)

d = jnp.ones(5)
print("d :", d)

e = jnp.eye(5)
print("e :", e)

f = jnp.identity(5)
print("f :", f)
```
Let’s understand the code above:
- Line 3: We call the `jnp.arange()` method, which generates a JAX array of 10 elements from 0 to 9.
- Line 6: We call the `jnp.linspace()` method, which creates a JAX array of 30 evenly spaced values from 0 to 10, with both endpoints included. By default, the `linspace()` method generates 50 values, but we can generate any number of values in a given range.
- Lines 9–10: We define a Python list, `scores`, and use the `jnp.array()` method to convert `scores` into a JAX array.
- Line 13: We call the `jnp.zeros()` method to generate a JAX array of 5 zeros.
- Line 16: Similarly, we call the `jnp.ones()` method to generate a JAX array of 5 ones.
- Line 19: We create a 5×5 identity matrix (a square matrix whose diagonal values are one and all other elements are zero) by calling the `jnp.eye()` method.
- Line 22: Just like the `jnp.eye()` method, the `jnp.identity()` method generates an identity matrix.
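The creation functions accept more arguments than the code above uses. The following sketch shows a few of them: `linspace()` without its `num` argument, `arange()` with a step, `zeros()` with a shape tuple and an explicit `dtype`, and one difference between `eye()` and `identity()`. Both build the same square identity matrix, but `eye()` also accepts a column count for non-square results.

```python
import jax.numpy as jnp

# linspace() falls back to num=50 when the count is omitted.
b_default = jnp.linspace(0, 10)
print(b_default.shape)  # (50,)

# arange(start, stop, step) works like Python's range().
evens = jnp.arange(0, 10, 2)
print(evens)  # [0 2 4 6 8]

# zeros() takes a shape tuple and an optional dtype.
m = jnp.zeros((2, 3), dtype=jnp.int32)
print(m.shape, m.dtype)  # (2, 3) int32

# eye(n) and identity(n) produce the same square matrix...
same = jnp.array_equal(jnp.eye(5), jnp.identity(5))
print(same)  # True

# ...but eye() can also build a non-square matrix with ones on the diagonal.
r = jnp.eye(3, 4)
print(r.shape)  # (3, 4)
```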