📏 Flax Metrics

Flax NNX implementation of common metrics.

>>> from flax_metrics import Precision, Recall
>>> from jax import numpy as jnp

>>> labels = jnp.asarray([ 0,  0,  0,  1,  1,  1])
>>> logits = jnp.asarray([-1, -2,  2,  1, -1, -2])

>>> metric = Recall()
>>> metric.update(labels=labels, logits=logits)
Recall(...)
>>> metric.compute()
Array(0.333..., dtype=float32)

Masking

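jax.jit() recompiles a function for every new input shape, and boolean-mask indexing such as labels[mask] produces an array whose shape depends on the values in the mask, so it cannot be used inside a jitted function. This makes evaluation on subsets challenging. Flax Metrics supports masking through the keyword-only argument mask, which keeps input shapes fixed. The example below illustrates that passing mask is equivalent to indexing the inputs with a boolean mask outside of jit.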

>>> mask = jnp.asarray([True, True, True, True, False, True])
>>> metric = Recall()
>>> metric.update(labels=labels, logits=logits, mask=mask)
Recall(...)
>>> metric.compute()
Array(0.5, dtype=float32)

>>> metric.reset()
Recall(...)
>>> metric.update(labels=labels[mask], logits=logits[mask])
Recall(...)
>>> metric.compute()
Array(0.5, dtype=float32)
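
To illustrate why a mask argument works under jax.jit(), here is a minimal sketch in plain JAX. It is not the Flax Metrics implementation; the function name masked_recall and the choice to threshold logits at zero are assumptions made for this example (the threshold is consistent with the outputs shown above). The mask is folded into the counts, so every array keeps its original shape.

```python
import jax
from jax import numpy as jnp


@jax.jit
def masked_recall(labels, logits, mask):
    # Hypothetical helper, not part of flax_metrics.
    # Assumption: a logit above zero counts as a positive prediction.
    predictions = logits > 0
    # Only unmasked positive labels contribute to the counts.
    positives = (labels == 1) & mask
    true_positives = jnp.sum(predictions & positives)
    actual_positives = jnp.sum(positives)
    # Yields nan when there are no unmasked positives.
    return true_positives / actual_positives


labels = jnp.asarray([0, 0, 0, 1, 1, 1])
logits = jnp.asarray([-1, -2, 2, 1, -1, -2])
mask = jnp.asarray([True, True, True, True, False, True])

recall = masked_recall(labels, logits, mask)  # 0.5
```

Because the mask only changes which elements are counted, calling the function again with a different mask of the same length reuses the compiled code.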

Available Metrics

Classification Metrics

Metrics for evaluating classifiers, operating on logits and binary, categorical, or multinomial labels.

Accuracy

Accuracy metric, the fraction of correct predictions.

Recall

Recall metric, the fraction of actual positives that were correctly identified.

Precision

Precision metric, the fraction of identified positives that are true positives.

F1Score

F1 score, the harmonic mean of precision and recall.

LogProb

Log probability score, the mean log-likelihood of the observed outcomes.
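
The definitions above can be checked by hand on the labels and logits from the introduction. The following is a sketch in plain JAX, not the library's code; thresholding logits at zero is an assumption that matches the recall value shown earlier.

```python
from jax import numpy as jnp

labels = jnp.asarray([0, 0, 0, 1, 1, 1])
logits = jnp.asarray([-1, -2, 2, 1, -1, -2])

# Assumption: a logit above zero is a positive prediction.
predictions = logits > 0
positives = labels == 1

# Confusion-matrix counts for the binary case.
tp = jnp.sum(predictions & positives)   # 1
fp = jnp.sum(predictions & ~positives)  # 1
fn = jnp.sum(~predictions & positives)  # 2

accuracy = jnp.mean(predictions == positives)       # 3 / 6
precision = tp / (tp + fp)                          # 1 / 2
recall = tp / (tp + fn)                             # 1 / 3
f1 = 2 * precision * recall / (precision + recall)  # 0.4
```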

Regression Metrics

Metrics for evaluating regression models.

LpError

L-p error metric with optional transform.

MeanAbsoluteError

Mean absolute error (MAE).

MeanSquaredError

Mean squared error (MSE).

RootMeanSquaredError

Root mean squared error (RMSE).

MeanSquaredLogError

Mean squared logarithmic error (MSLE).

RootMeanSquaredLogError

Root mean squared logarithmic error (RMSLE).
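
For reference, the textbook formulas behind these names can be written in a few lines of plain JAX. This is a sketch, not the library's implementation; in particular, using log1p (the logarithm of one plus the value) for the logarithmic errors is an assumption that follows a common convention and may differ from what Flax Metrics does.

```python
from jax import numpy as jnp

targets = jnp.asarray([1.0, 2.0, 4.0])
predictions = jnp.asarray([1.0, 4.0, 2.0])

errors = predictions - targets   # [0, 2, -2]
mae = jnp.mean(jnp.abs(errors))  # 4 / 3
mse = jnp.mean(errors ** 2)      # 8 / 3
rmse = jnp.sqrt(mse)

# Assumption: logarithmic errors compare log(1 + x) of each value.
log_errors = jnp.log1p(predictions) - jnp.log1p(targets)
msle = jnp.mean(log_errors ** 2)
rmsle = jnp.sqrt(msle)
```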

Ranking Metrics

Metrics for evaluating ranked retrieval results using precomputed scores.

PrecisionAtK

Precision@K, the fraction of top-k items that are relevant.

RecallAtK

Recall@K, the fraction of relevant items that appear in the top-k ranked results.

MeanReciprocalRank

Mean Reciprocal Rank, the average over queries of the reciprocal rank of the first relevant item.

MeanAveragePrecision

Mean Average Precision, the average over queries of the precision measured at the rank of each relevant item.

NDCG

Normalized Discounted Cumulative Gain, the discounted cumulative gain of a ranking divided by that of the ideal ranking.
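
As a worked example for a single query, the sketch below ranks five items by score and evaluates several of these quantities in plain JAX. The variable names and the log2 discount used for NDCG are assumptions for illustration; the library's inputs and conventions may differ.

```python
from jax import numpy as jnp

scores = jnp.asarray([0.9, 0.1, 0.8, 0.3, 0.5])
relevance = jnp.asarray([0, 1, 1, 0, 1])
k = 3

# Rank items by descending score.
order = jnp.argsort(-scores)         # [0, 2, 4, 3, 1]
ranked_relevance = relevance[order]  # [0, 1, 1, 0, 1]

hits_at_k = jnp.sum(ranked_relevance[:k])
precision_at_k = hits_at_k / k                 # 2 / 3
recall_at_k = hits_at_k / jnp.sum(relevance)   # 2 / 3

# Reciprocal rank of the first relevant item (assumes one exists).
first_relevant = jnp.argmax(ranked_relevance)  # position 1, zero-based
reciprocal_rank = 1.0 / (first_relevant + 1)   # 0.5

# NDCG with the common log2 discount (an assumption of this sketch).
discounts = jnp.log2(jnp.arange(relevance.shape[0]) + 2.0)
dcg = jnp.sum(ranked_relevance / discounts)
ideal_dcg = jnp.sum(jnp.sort(relevance)[::-1] / discounts)
ndcg = dcg / ideal_dcg
```

Averaging reciprocal_rank over many queries gives the mean reciprocal rank.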

Dot Product Ranking Metrics

Ranking metrics where scores are computed as dot products between query and key embeddings. Useful for dense retrieval and embedding-based recommendation systems.

DotProductPrecisionAtK

Precision@K using dot product scores between query and key embeddings.

DotProductRecallAtK

Recall@K using dot product scores between query and key embeddings.

DotProductMeanReciprocalRank

Mean Reciprocal Rank using dot product scores between query and key embeddings.

DotProductMeanAveragePrecision

Mean Average Precision using dot product scores between query and key embeddings.

DotProductNDCG

Normalized Discounted Cumulative Gain using dot product scores between query and key embeddings.
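
The scoring step these metrics share can be sketched as a single matrix product. The array names and shapes below are assumptions for illustration, not the signature of the DotProduct classes.

```python
from jax import numpy as jnp

# Assumed shapes: (num_queries, dim) and (num_keys, dim).
queries = jnp.asarray([[1.0, 0.0], [0.0, 1.0]])
keys = jnp.asarray([[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]])

# One score per query-key pair, shape (num_queries, num_keys).
scores = queries @ keys.T  # [[1.0, 0.5, 0.0], [0.0, 0.5, 2.0]]

# Index of the highest-scoring key for each query.
top1 = jnp.argmax(scores, axis=1)  # [0, 2]
```

Each row of scores can then be ranked and evaluated like the precomputed scores in the previous section.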

Base Metrics

General-purpose metrics for aggregating values.

Average

Average metric, the arithmetic mean of values.

Statistics

Statistics computed by the Welford metric.

Welford

Welford metric, computing running mean and variance using Welford's algorithm.
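
Welford's algorithm updates the mean and the sum of squared deviations one value at a time, which avoids storing the data. The sketch below shows the scalar update rule; the function name welford_update and the (count, mean, m2) state layout are assumptions for this example, and whether the library reports the population or sample variance is not specified here.

```python
from jax import numpy as jnp


def welford_update(state, value):
    # Hypothetical helper, not part of flax_metrics.
    count, mean, m2 = state
    count = count + 1
    delta = value - mean
    mean = mean + delta / count
    # m2 accumulates the sum of squared deviations from the running mean.
    m2 = m2 + delta * (value - mean)
    return count, mean, m2


values = jnp.asarray([2.0, 4.0, 4.0, 5.0])
state = (0, 0.0, 0.0)
for value in values:
    state = welford_update(state, value)

count, mean, m2 = state
variance = m2 / count  # population variance; use count - 1 for the sample variance
```

The running results match jnp.mean(values) and jnp.var(values) up to floating-point error.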