📏 Flax Metrics
Flax NNX implementation of common metrics.
>>> from flax_metrics import Precision, Recall
>>> from jax import numpy as jnp
>>> labels = jnp.asarray([ 0, 0, 0, 1, 1, 1])
>>> logits = jnp.asarray([-1, -2, 2, 1, -1, -2])
>>> metric = Recall()
>>> metric.update(labels=labels, logits=logits)
Recall(...)
>>> metric.compute()
Array(0.333..., dtype=float32)
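Masking
jax.jit() compiles a separate version of a function for every input shape, and boolean-mask indexing such as labels[mask] cannot be applied to traced arrays because the shape of the result depends on the values in the mask. Evaluating a metric on a subset of a batch by indexing is therefore either impossible inside a jitted function or forces a re-compilation for every subset size. Flax Metrics instead supports masking through the keyword-only argument mask. The example below illustrates that passing mask is equivalent to indexing the inputs with the boolean mask.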
>>> mask = jnp.asarray([True, True, True, True, False, True])
>>> metric = Recall()
>>> metric.update(labels=labels, logits=logits, mask=mask)
Recall(...)
>>> metric.compute()
Array(0.5, dtype=float32)
>>> metric.reset()
Recall(...)
>>> metric.update(labels=labels[mask], logits=logits[mask])
Recall(...)
>>> metric.compute()
Array(0.5, dtype=float32)
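To see why a mask argument keeps shapes static, here is a minimal sketch of a masked recall computed directly with jax.numpy. It assumes a prediction is positive when its logit is greater than zero, which is consistent with the outputs above; masked_recall is a hypothetical helper for illustration, not part of Flax Metrics. Excluded elements contribute zero to both counts instead of being dropped, so every array keeps its original shape and the function compiles once.

```python
import jax
from jax import numpy as jnp


@jax.jit
def masked_recall(labels, logits, mask):
    """Recall over the elements where `mask` is True (hypothetical helper)."""
    # Assumption: a prediction is positive when its logit is above zero.
    predicted = logits > 0
    actual = labels == 1
    # Multiplying by the mask zeroes out excluded elements instead of
    # removing them, so all arrays keep their static shapes under jit.
    weight = mask.astype(jnp.float32)
    true_positives = jnp.sum(weight * (predicted & actual))
    actual_positives = jnp.sum(weight * actual)
    return true_positives / actual_positives


labels = jnp.asarray([0, 0, 0, 1, 1, 1])
logits = jnp.asarray([-1, -2, 2, 1, -1, -2])
mask = jnp.asarray([True, True, True, True, False, True])

print(masked_recall(labels, logits, mask))
# Boolean indexing gives the same value, but it happens outside jit and the
# new input shape (5 instead of 6) triggers a second compilation.
print(masked_recall(labels[mask], logits[mask], jnp.ones(5, dtype=bool)))
```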
Available Metrics
Classification Metrics
Metrics for evaluating classifiers, operating on logits and binary, categorical, or multinomial labels.
Accuracy metric, the fraction of correct predictions.
Recall metric, the fraction of actual positives that were correctly identified.
Precision metric, the fraction of predicted positives that are true positives.
F1 score, the harmonic mean of precision and recall.
Log probability score, the mean log-likelihood of the observed outcomes.
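For binary labels, the definitions above can be sketched with plain jax.numpy. This is an illustration under stated assumptions, not the library's implementation: it assumes a prediction is positive when its logit is above zero, and that the log probability score averages log sigmoid(logit) over positive labels and log sigmoid(-logit) over negative ones. The function name binary_scores is hypothetical.

```python
import jax
from jax import numpy as jnp


def binary_scores(labels, logits):
    """Binary classification metrics from logits (hypothetical helper)."""
    # Assumption: predict the positive class when the logit exceeds zero.
    predicted = logits > 0
    actual = labels == 1
    tp = jnp.sum(predicted & actual)   # true positives
    fp = jnp.sum(predicted & ~actual)  # false positives
    fn = jnp.sum(~predicted & actual)  # false negatives
    accuracy = jnp.mean(predicted == actual)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    # Log probability of the observed label: log sigmoid(z) for label 1 and
    # log sigmoid(-z) = log(1 - sigmoid(z)) for label 0.
    signed = jnp.where(actual, logits, -logits)
    log_prob = jnp.mean(jax.nn.log_sigmoid(signed))
    return dict(accuracy=accuracy, precision=precision, recall=recall,
                f1=f1, log_prob=log_prob)


labels = jnp.asarray([0, 0, 0, 1, 1, 1])
logits = jnp.asarray([-1.0, -2.0, 2.0, 1.0, -1.0, -2.0])
for name, value in binary_scores(labels, logits).items():
    print(f"{name}: {float(value):.3f}")
```

With these inputs the recall agrees with the Recall doctest at the top of the page.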
Regression Metrics
Metrics for evaluating regression models.
Lp error metric with an optional transform.
Mean absolute error (MAE).
Mean squared error (MSE).
Root mean squared error (RMSE).
Mean squared logarithmic error (MSLE).
Root mean squared logarithmic error (RMSLE).
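A minimal sketch of how these regression errors relate, assuming the Lp error is the mean of |f(prediction) - f(target)|^p for an optional elementwise transform f, and that the logarithmic variants use f(x) = log(1 + x), the common convention (it is what scikit-learn uses for MSLE). The helper lp_error and its parameters are hypothetical and may not match the library's signature.

```python
from jax import numpy as jnp


def lp_error(prediction, target, p=2.0, transform=None):
    """Mean of |f(prediction) - f(target)| ** p (hypothetical helper)."""
    if transform is not None:
        # Apply the optional elementwise transform before differencing.
        prediction, target = transform(prediction), transform(target)
    return jnp.mean(jnp.abs(prediction - target) ** p)


prediction = jnp.asarray([2.5, 0.0, 2.0, 8.0])
target = jnp.asarray([3.0, 0.5, 2.0, 7.0])

mae = lp_error(prediction, target, p=1.0)                        # mean |error|
mse = lp_error(prediction, target, p=2.0)                        # mean error ** 2
rmse = jnp.sqrt(mse)
msle = lp_error(prediction, target, p=2.0, transform=jnp.log1p)  # assumes log(1 + x)
rmsle = jnp.sqrt(msle)
print(float(mae), float(mse), float(rmse), float(msle), float(rmsle))
```

Expressing the errors through one parameterized function makes the family explicit: MAE and MSE differ only in p, and the logarithmic variants only add a transform.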
Ranking Metrics
Metrics for evaluating ranked retrieval results using precomputed scores.
Precision@K, the fraction of top-k items that are relevant.
Recall@K, the fraction of relevant items that appear in the top-k ranked results.
Mean Reciprocal Rank.
Mean Average Precision.
Normalized Discounted Cumulative Gain.
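For a single query, the ranking metrics listed above can be sketched from a score vector and binary relevance labels. The sketch assumes binary relevance, ranks items by descending score, normalizes average precision by the total number of relevant items, and uses the common 1 / log2(rank + 1) discount for NDCG; the library may use different conventions (for example, truncating reciprocal rank or average precision at k). The function ranking_metrics is hypothetical. Averaging the reciprocal rank and average precision over queries gives the "mean" in Mean Reciprocal Rank and Mean Average Precision.

```python
from jax import numpy as jnp


def ranking_metrics(scores, relevant, k):
    """Single-query ranking metrics with binary relevance (hypothetical)."""
    order = jnp.argsort(-scores)                 # item indices, best score first
    hits = relevant[order].astype(jnp.float32)   # 1.0 where the ranked item is relevant
    ranks = jnp.arange(1, hits.shape[0] + 1)     # 1-based ranks
    num_relevant = jnp.sum(hits)

    precision_at_k = jnp.sum(hits[:k]) / k
    recall_at_k = jnp.sum(hits[:k]) / num_relevant
    # Reciprocal rank of the first relevant item (0 if none is relevant).
    reciprocal_rank = jnp.max(hits / ranks)
    # Average precision: precision at each relevant position, averaged.
    precision_at_rank = jnp.cumsum(hits) / ranks
    average_precision = jnp.sum(precision_at_rank * hits) / num_relevant
    # NDCG@k: discounted gain, normalized by the best possible ordering.
    discounts = 1.0 / jnp.log2(ranks + 1.0)
    dcg = jnp.sum((hits * discounts)[:k])
    ideal = jnp.sort(hits)[::-1]
    idcg = jnp.sum((ideal * discounts)[:k])
    ndcg_at_k = dcg / idcg
    return precision_at_k, recall_at_k, reciprocal_rank, average_precision, ndcg_at_k


scores = jnp.asarray([0.1, 0.9, 0.4, 0.7, 0.2])
relevant = jnp.asarray([True, False, True, False, False])
p, r, rr, ap, ndcg = ranking_metrics(scores, relevant, k=3)
print(f"P@3={float(p):.3f} R@3={float(r):.3f} RR={float(rr):.3f} "
      f"AP={float(ap):.3f} NDCG@3={float(ndcg):.3f}")
```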
Dot Product Ranking Metrics
Ranking metrics where scores are computed as dot products between query and key embeddings. Useful for dense retrieval and embedding-based recommendation systems.
Precision@K using dot product scores between query and key embeddings.
Recall@K using dot product scores between query and key embeddings.
Mean Reciprocal Rank using dot product scores between query and key embeddings.
Mean Average Precision using dot product scores between query and key embeddings.
Normalized Discounted Cumulative Gain using dot product scores between query and key embeddings.
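The dot product variants differ from the ranking metrics only in how the scores are produced: for a batch of query embeddings and a set of key embeddings, the score matrix is queries @ keys.T, and each row is then ranked as usual. A minimal sketch of mean Recall@K over a batch, assuming binary relevance is given as a [num_queries, num_keys] matrix; the function name and argument layout are hypothetical.

```python
import jax
from jax import numpy as jnp


def dot_product_recall_at_k(queries, keys, relevance, k):
    """Mean Recall@K with scores from query-key dot products (hypothetical)."""
    scores = queries @ keys.T                   # [num_queries, num_keys]
    _, top_k = jax.lax.top_k(scores, k)         # indices of the k best keys per query
    hits = jnp.take_along_axis(relevance, top_k, axis=1).astype(jnp.float32)
    recall = jnp.sum(hits, axis=1) / jnp.sum(relevance, axis=1)
    return jnp.mean(recall)                     # average over queries


queries = jnp.asarray([[1.0, 0.0], [0.0, 1.0]])
keys = jnp.asarray([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
relevance = jnp.asarray([[True, False, True], [False, False, True]])
print(dot_product_recall_at_k(queries, keys, relevance, k=1))
print(dot_product_recall_at_k(queries, keys, relevance, k=2))
```

Computing all scores with a single matrix product keeps the operation batched and shape-static, which suits jitted evaluation loops.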
Base Metrics
General-purpose metrics for aggregating values.
Average metric, the arithmetic mean of values.
Statistics computed by the Welford metric.
Welford metric, computing running mean and variance using Welford's algorithm.
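Welford's algorithm keeps a count, a running mean, and a running sum of squared deviations (often called M2), so the variance can be read off at any point without storing past values. Because metrics are usually updated one batch at a time, the sketch below folds in whole batches using the standard pairwise combination formula, which generalizes Welford's single-value update. The state layout and the names WelfordState, init, update, and variance are hypothetical and not necessarily how the Welford metric stores its statistics.

```python
from typing import NamedTuple

from jax import numpy as jnp


class WelfordState(NamedTuple):
    """Running statistics (hypothetical layout)."""
    count: jnp.ndarray
    mean: jnp.ndarray
    m2: jnp.ndarray  # sum of squared deviations from the running mean


def init():
    zero = jnp.asarray(0.0)
    return WelfordState(count=zero, mean=zero, m2=zero)


def update(state, batch):
    """Fold a batch into the running statistics without storing past values."""
    n_b = batch.shape[0]
    mean_b = jnp.mean(batch)
    m2_b = jnp.sum((batch - mean_b) ** 2)
    count = state.count + n_b
    delta = mean_b - state.mean
    mean = state.mean + delta * n_b / count
    m2 = state.m2 + m2_b + delta**2 * state.count * n_b / count
    return WelfordState(count, mean, m2)


def variance(state, ddof=0):
    """Population variance for ddof=0, sample variance for ddof=1."""
    return state.m2 / (state.count - ddof)


values = jnp.asarray([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])
state = init()
for batch in (values[:3], values[3:5], values[5:]):
    state = update(state, batch)
# The streamed statistics should match the ones computed on all values at once.
print(state.mean, variance(state), jnp.mean(values), jnp.var(values))
```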