Machine Learning with JAX
Learn about various machine learning functionalities available in the JAX library.
Taking derivatives with grad()
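In JAX, derivatives are computed with jax.grad, which takes a scalar-valued function and returns a new function that evaluates its gradient (by default, with respect to the first argument).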
```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))
```
In the code above:
- Lines 1–2: We import JAX and its NumPy-compatible module, jax.numpy, as jnp.
- Lines 4–6: We apply the @jax.jit decorator to the sum_logistic() function, which sums the logistic function over the elements of its input.
- Line 8: We generate a JAX array of values from zero to five and store it in x_small.
- Line 9: We use the jax.grad() function to create a function that computes the derivative of sum_logistic() with respect to its input, and store it in derivative_fn.
- Lines 10–11: We print the original JAX array, x_small, and its derivative computed with derivative_fn.
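Because each element of the input contributes independently to the sum, the gradient of sum_logistic() should equal the derivative of the logistic function, σ(x)(1 − σ(x)), applied elementwise. The minimal check below compares jax.grad against that closed form; analytic_grad is a hypothetical helper introduced only for this comparison, and the tolerance assumes JAX's default float32 precision.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    # Sum of the logistic (sigmoid) function applied elementwise.
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def analytic_grad(x):
    # Closed-form derivative of the sigmoid, s(x) * (1 - s(x)), elementwise.
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s)

x_small = jnp.arange(6.)
auto = jax.grad(sum_logistic)(x_small)   # gradient from automatic differentiation
manual = analytic_grad(x_small)          # gradient from the formula
print("jax.grad:", auto)
print("analytic:", manual)
print("match:", jnp.allclose(auto, manual, atol=1e-6))
```

At x = 0 both give 0.25, since σ(0) = 0.5.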
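The grad function has a has_aux argument for functions that return auxiliary data alongside the value being differentiated. With has_aux=True, the function must return a pair (output, aux): grad differentiates only output and returns (gradient, aux). For example, when building machine learning models, a loss function can return the loss together with predictions or metrics that we want to inspect but not differentiate.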
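```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print("Original: ", x_small)
print("Derivative: ", derivative_fn(x_small))
```
In the code above:
- Line 6: sum_logistic() now returns a pair: the scalar sum to differentiate and the auxiliary data, x + 1.
- Line 9: We pass True to the has_aux argument so that jax.grad() treats the second element of the returned pair as auxiliary data. Only the first element is differentiated; the auxiliary data is passed through unchanged.
- Line 11: We print the result of derivative_fn(x_small), which is a tuple (gradient, aux). The output shows the auxiliary data, x_small + 1, alongside the derivative results.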
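In a training setting, has_aux is typically paired with a loss function that returns (loss, extra), where extra is something we want to inspect without differentiating, such as the model's predictions. If we also need the loss value itself, jax.value_and_grad accepts the same has_aux flag and returns ((loss, aux), grads). A minimal sketch; the linear model, loss_fn, and the data are hypothetical illustrations.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Hypothetical linear model: one prediction per row of x.
    preds = x @ w
    loss = jnp.mean((preds - y) ** 2)  # mean squared error
    # Return (value to differentiate, auxiliary data).
    return loss, preds

w = jnp.array([0.5, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, -1.0])

# has_aux=True: only `loss` is differentiated; `preds` is returned unchanged.
grads, preds = jax.grad(loss_fn, has_aux=True)(w, x, y)

# value_and_grad also returns the loss, nested together with the aux data.
(loss, preds2), grads2 = jax.value_and_grad(loss_fn, has_aux=True)(w, x, y)

print("loss:", loss)    # 6.25
print("grads:", grads)  # [-15. -20.]
print("preds:", preds)  # [-1.5 -2.5 -3.5]
```

Here every residual is −2.5, so the loss is 6.25 and the gradient (2/3)·xᵀ(preds − y) is [−15, −20].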
For more advanced automatic differentiation, JAX also provides jax.jvp() for forward-mode Jacobian-vector products and jax.vjp() for reverse-mode vector-Jacobian products.
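jax.jvp evaluates a function together with the Jacobian-vector product J(x)v in forward mode, while jax.vjp returns the function's value and a function that computes the vector-Jacobian product vᵀJ(x) in reverse mode, the mode jax.grad uses. The sketch below uses a hypothetical function f with a non-symmetric Jacobian so the two products differ, and checks both against the full Jacobian from jax.jacfwd.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical map R^3 -> R^3 whose Jacobian is not symmetric.
    return jnp.stack([x[0] * x[1], x[1] + x[2], x[0] ** 2])

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])

# Forward mode: primals and tangents are passed as tuples.
y, jvp_out = jax.jvp(f, (x,), (v,))   # jvp_out = J(x) @ v

# Reverse mode: vjp returns f(x) and a function that pulls back a cotangent.
y2, vjp_fn = jax.vjp(f, x)
(vjp_out,) = vjp_fn(v)                # vjp_out = v @ J(x), one entry per input

# Full Jacobian, for comparison only.
J = jax.jacfwd(f)(x)
print("J @ v:", jvp_out)  # first column of J: [2. 0. 2.]
print("v @ J:", vjp_out)  # first row of J:    [2. 1. 0.]
```

Computing a full Jacobian costs one jvp per input or one vjp per output, which is why reverse mode (grad) suits scalar losses over many parameters.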
Auto-vectorization with vmap
jax.vmap (vectorizing map) lets us write a function that operates on a single data point and then maps it over a batch of data. Without vmap, the solution would be to loop through the ...
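As a minimal sketch of the idea, the hypothetical predict function below is written for a single example. Calling jax.vmap with in_axes=(None, 0) maps it over the leading axis of a batch while sharing the weights across all examples, and it gives the same result as an explicit Python loop.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Written for one example: x has shape (features,).
    return jnp.dot(w, x)

w = jnp.array([1.0, 2.0, 3.0])
batch = jnp.arange(12.0).reshape(4, 3)  # 4 examples, 3 features each

# Loop version: apply predict to one row at a time, then stack the results.
looped = jnp.stack([predict(w, row) for row in batch])

# vmap version: do not batch `w` (None), batch `batch` along axis 0.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
vectorized = batched_predict(w, batch)

print("looped:    ", looped)      # [ 8. 26. 44. 62.]
print("vectorized:", vectorized)  # [ 8. 26. 44. 62.]
```

The vmapped version avoids the Python-level loop, so JAX can trace it once and, combined with jax.jit, compile it as a single batched operation.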