Machine Learning with JAX

Learn about various machine learning functionalities available in the JAX library.

Taking derivatives with grad()

Computing derivatives in JAX is done using jax.grad.

Python 3.8
import jax
import jax.numpy as jnp

@jax.jit
def sum_logistic(x):
    # Sum of the logistic (sigmoid) function applied elementwise to x
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))

In the code above:

  • Lines 1–2: We import jax and the jax.numpy module as jnp.

  • Lines 4–7: We apply the @jax.jit decorator to the sum_logistic() function, which returns the sum of the logistic function applied elementwise to its input.

  • Line 9: We generate a JAX array of values from zero to five and store it in x_small.

  • Line 10: We use the jax.grad() function to calculate the derivative of the sum_logistic() function with respect to its input and store the resulting function in derivative_fn.

  • Lines 11–12: We print the original JAX array, x_small, and its derivative computed with derivative_fn; the sketch below checks this output against the analytic formula.
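
Because sum_logistic() adds up the logistic (sigmoid) function elementwise, the gradient with respect to each element is the sigmoid's own derivative, sigmoid(x) * (1 - sigmoid(x)). The following is a minimal sketch that checks the jax.grad() output against this analytic formula; the sigmoid() helper and the jnp.allclose() comparison are illustrative additions, not part of the lesson code.

Python 3.8
import jax
import jax.numpy as jnp

def sigmoid(x):
    # Elementwise logistic function
    return 1.0 / (1.0 + jnp.exp(-x))

x_small = jnp.arange(6.)
auto_grad = jax.grad(lambda x: jnp.sum(sigmoid(x)))(x_small)
analytic_grad = sigmoid(x_small) * (1.0 - sigmoid(x_small))

# Both gradients should agree up to floating-point tolerance
print(jnp.allclose(auto_grad, analytic_grad))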

The jax.grad() function has a has_aux argument that allows the differentiated function to return auxiliary data alongside the value being differentiated. For example, when building machine learning models, we can use it to get the gradients of a loss function while also returning the loss value or other intermediate outputs as auxiliary data.

Python 3.8
import jax
import jax.numpy as jnp

@jax.jit
def sum_logistic(x):
    # Return the value to differentiate plus (x + 1) as auxiliary data
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))

In the code above:

  • Line 10: We pass True to the has_aux argument to tell jax.grad() that sum_logistic() returns a pair of (output to differentiate, auxiliary data); only the first element is differentiated, and the auxiliary data is passed through unchanged.
  • Line 12: We print the result of calling derivative_fn on x_small. We can see the auxiliary data, (x + 1), along with the derivative in the output; the sketch below shows how to unpack the two parts.
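
Because has_aux=True makes jax.grad() return a (gradients, auxiliary data) pair, we can also unpack the two parts explicitly. Here is a minimal, self-contained sketch of that pattern; the variable names grads and aux are illustrative.

Python 3.8
import jax
import jax.numpy as jnp

def sum_logistic(x):
    # The first element is differentiated; the second is passed through as auxiliary data
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)

x_small = jnp.arange(6.)
grads, aux = jax.grad(sum_logistic, has_aux=True)(x_small)
print("Gradients: ", grads)
print("Auxiliary: ", aux)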

We can perform advanced automatic differentiation using jax.vjp() and jax.jvp().
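
As a brief sketch of these lower-level primitives (the choice of function, tangent vector, and cotangent value below is illustrative):

Python 3.8
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x = jnp.arange(6.)

# Forward mode: jax.jvp() returns f(x) and the Jacobian-vector product with a tangent vector
value, jvp_out = jax.jvp(f, (x,), (jnp.ones_like(x),))

# Reverse mode: jax.vjp() returns f(x) and a function that computes vector-Jacobian products
value2, vjp_fn = jax.vjp(f, x)
(vjp_out,) = vjp_fn(1.0)

print(value, jvp_out)
print(value2, vjp_out)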

Auto-vectorization with vmap

The vmap (vectorizing map) transformation allows us to write a function that operates on a single data point, and vmap then maps it over a batch of data. Without vmap, the solution would be to loop through the batch and apply the function to each element one at a time.
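
As a minimal sketch of this idea, assuming an illustrative dot-product function and a batch of four vector pairs:

Python 3.8
import jax
import jax.numpy as jnp

def dot(v, w):
    # Written for a single pair of vectors
    return jnp.dot(v, w)

batch_v = jnp.ones((4, 3))
batch_w = jnp.ones((4, 3))

# jax.vmap() maps dot over the leading (batch) dimension of both arguments
batched_dot = jax.vmap(dot)
print(batched_dot(batch_v, batch_w))  # four dot products, shape (4,)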
