
Computing and Monitoring Loss in JAX

Learn about various ways to calculate and monitor loss in models using JAX.

Computing loss with JAX metrics

JAX Metrics (imported as jax_metrics) is an open-source package for computing losses and metrics in JAX through a Keras-like API. For example, here is how to use the library to compute the cross-entropy loss:

Python 3.8
import jax.numpy as jnp  # needed for jnp.array below
import jax_metrics as jm

crossentropy = jm.losses.Crossentropy()

logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.0, 1.0, 1.0, 0.0, 0.0])

print(crossentropy(target=labels, preds=logits))
print(jm.losses.crossentropy(target=labels, preds=logits))

In the code above:

  • Lines 1–2: We import jax.numpy to build the arrays and the jax_metrics library to calculate the cross-entropy loss.

  • Line 4: We create an instance of the Crossentropy loss class to compute the loss.

  • Lines 6–7: We define two JAX arrays: logits and labels.

  • Line 9: We compute the cross entropy by calling the crossentropy instance created on line 4.

  • Line 10: We compute the cross entropy by calling the jm.losses.crossentropy() function directly. It is an alternative, functional syntax to compute the same loss. ...
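
To see what a cross-entropy loss computes, here is a minimal sketch written in plain JAX, without the jax_metrics library. It assumes the five values are unnormalized scores for five classes of a single example and that the labels act as per-class weights. The helper name softmax_cross_entropy is hypothetical, and the library's defaults may differ, so the printed value is not guaranteed to match the library's output.

```python
import jax
import jax.numpy as jnp


def softmax_cross_entropy(logits, labels):
    """Cross entropy between labels and softmax(logits) along the last axis.

    Hypothetical helper for illustration; not part of jax_metrics.
    """
    # log_softmax is numerically safer than taking log(softmax(x)).
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(labels * log_probs, axis=-1)


logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.0, 1.0, 1.0, 0.0, 0.0])

# Scalar loss for this single example.
loss = softmax_cross_entropy(logits, labels)
print(loss)

# jax.grad differentiates the loss with respect to its first argument
# (the logits), which is the quantity a training step would use.
grads = jax.grad(softmax_cross_entropy)(logits, labels)
print(grads)
```

Writing the loss as a pure function of arrays keeps it compatible with JAX transformations such as jax.grad and jax.jit, which is also why a functional form of a loss is convenient inside a training step.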
