Classification Metrics

Metrics for evaluating classifiers, including accuracy, precision, recall, and F1 score. These metrics operate on logits and binary or multinomial labels, applying a threshold to convert logits to point estimates where required.

class flax_metrics.classification.Accuracy(threshold: float | None = None)

Accuracy metric, the fraction of correct predictions.

For multi-class classification, the predicted class is the argmax of the logits over the class axis and is compared to the labels. For binary classification, pass a threshold to determine positive predictions.
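
As a rough illustration of the two modes described above, the sketch below derives hard predictions directly with jax.numpy rather than through Accuracy itself. The helper name predict is hypothetical, and taking the argmax over the last axis is an assumption based on the documented logits shape (..., num_classes).

```python
from jax import numpy as jnp


def predict(logits, threshold=None):
    """Hypothetical helper: turn logits into hard predictions."""
    if threshold is None:
        # Multi-class: pick the highest-scoring class along the last axis (assumed).
        return jnp.argmax(logits, axis=-1)
    # Binary: logits >= threshold count as positive, as documented for `threshold`.
    return (logits >= threshold).astype(jnp.int32)


multi_logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
multi_preds = predict(multi_logits)  # classes 1, 0, 1

binary_logits = jnp.array([0.6, 0.4, 0.8, 0.3])
binary_preds = predict(binary_logits, threshold=0.5)  # 1, 0, 1, 0

# Accuracy is the fraction of predictions that match the labels.
multi_hits = multi_preds == jnp.array([1, 1, 1])
binary_hits = binary_preds == jnp.array([1, 1, 1, 0])
multi_accuracy = multi_hits.astype(jnp.float32).mean()
binary_accuracy = binary_hits.astype(jnp.float32).mean()
```

With these inputs, two of three multi-class predictions and three of four binary predictions match their labels.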

See also

This metric is implemented in scikit-learn as sklearn.metrics.accuracy_score().

Parameters:

threshold – For binary classification, logits >= threshold are considered positive. If None (default), multi-class classification is assumed.

Example

Multi-class classification:

>>> from jax import numpy as jnp
>>> from flax_metrics import Accuracy
>>>
>>> logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
>>> labels = jnp.array([1, 1, 1])
>>> metric = Accuracy()
>>> metric.update(logits=logits, labels=labels)
Accuracy(...)
>>> metric.compute()
Array(0.666..., dtype=float32)

Binary classification:

>>> logits = jnp.array([0.6, 0.4, 0.8, 0.3])
>>> labels = jnp.array([1, 1, 1, 0])
>>> metric = Accuracy(threshold=0.5)
>>> metric.update(logits=logits, labels=labels)
Accuracy(...)
>>> metric.compute()
Array(0.75, dtype=float32)

compute() → Array

Compute and return the accuracy.

reset() → Self

Reset the metric state.

update(logits: Array, labels: Array, *, mask: Array | None = None, **_) → Self

Update the metric with a batch of predictions.

Parameters:
  • logits – Predicted logits. For multi-class, shape (..., num_classes). For binary, shape (...,).

  • labels – Ground truth integer labels with shape (...,).

  • mask – Binary mask indicating which elements to include.
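
To illustrate how a mask and repeated update calls could interact, here is a minimal sketch that keeps running counts by hand. The helper update_counts and the state variables correct and total are hypothetical; the metric's actual internal state is not documented here, and treating a missing mask as "include everything" is an assumption.

```python
from jax import numpy as jnp


def update_counts(correct, total, logits, labels, mask=None):
    """Hypothetical helper: add one batch to running accuracy counts."""
    preds = jnp.argmax(logits, axis=-1)
    if mask is None:
        # Assumed default: every element is included.
        mask = jnp.ones(labels.shape, dtype=bool)
    mask = mask.astype(bool)
    hits = (preds == labels) & mask
    return correct + hits.sum(), total + mask.sum()


correct, total = jnp.array(0), jnp.array(0)

# Batch 1: the last element is padding and is masked out.
logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = jnp.array([1, 1, 1])
mask = jnp.array([1, 1, 0])
correct, total = update_counts(correct, total, logits, labels, mask)

# Batch 2: no mask, so both elements count.
logits = jnp.array([[2.0, -1.0], [0.0, 3.0]])
labels = jnp.array([0, 1])
correct, total = update_counts(correct, total, logits, labels)

accuracy = correct / total  # 3 correct out of 4 included elements
```

Pooling counts across batches, rather than averaging per-batch accuracies, keeps the result independent of how the data is split into batches.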

class flax_metrics.classification.F1Score(threshold: float = 0.0)

F1 score, the harmonic mean of precision and recall.

See also

This metric is implemented in scikit-learn as sklearn.metrics.f1_score().

Parameters:

threshold – Threshold for identifying items as positives.
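
The definition can be made concrete by counting true positives, false positives, and false negatives by hand, using the same data as the example in this section. The helper f1_from_logits is hypothetical, and comparing with >= is an assumption carried over from the Accuracy threshold description.

```python
from jax import numpy as jnp


def f1_from_logits(labels, logits, threshold=0.0):
    """Hypothetical helper: F1 from thresholded logits."""
    preds = logits >= threshold  # assumed comparison
    actual = labels == 1
    tp = jnp.sum(preds & actual)
    fp = jnp.sum(preds & ~actual)
    fn = jnp.sum(~preds & actual)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    # Harmonic mean of precision and recall.
    return 2 * precision * recall / (precision + recall)


labels = jnp.array([0, 0, 0, 1, 1, 1, 1])
logits = jnp.array([-1, -1, 1, 1, 1, -1, -1])
f1 = f1_from_logits(labels, logits)  # tp=2, fp=1, fn=2, so F1 = 4/7
```

The sketch does not guard against empty denominators; how the library handles that case is not covered here.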

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import F1Score
>>>
>>> labels = jnp.array([ 0,  0,  0,  1,  1,  1,  1])
>>> logits = jnp.array([-1, -1,  1,  1,  1, -1, -1])
>>> metric = F1Score()
>>> metric.update(labels=labels, logits=logits)
F1Score(...)
>>> metric.compute()
Array(0.5714286, dtype=float32)

compute() → Array

Compute and return the F1 score.

reset() → Self

Reset the metric state in place.

update(labels: Array, logits: Array, *, mask: Array | None = None, **_) → Self

Update the metric with a batch of predictions.

Parameters:
  • labels – Ground truth binary labels.

  • logits – Predicted logits.

  • mask – Binary mask indicating which elements to include.

class flax_metrics.classification.LogProb

Log probability score, the mean log-likelihood of the observed outcomes.

The metric supports three modes:

  1. Binary classification if the logits and labels have shape (..., 1).

  2. Categorical classification if the inputs have shape (..., num_classes) and the labels are one-hot encoded, i.e., labels.sum(axis=-1) == 1.

  3. Multinomial outcomes if the inputs have shape (..., num_classes) and the labels are many-hot encoded, i.e., labels.sum(axis=-1) > 1.

Categorical and multinomial outcomes may be mixed within the same batch because a multinomial outcome with a single trial is equivalent to a categorical outcome.
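
The equivalence between the categorical and multinomial modes can be checked with a small sketch. It assumes the score is the multinomial log-probability under softmax(logits), including the multinomial coefficient; the helper multinomial_log_prob is hypothetical and is not part of the library. Under that assumption, the many-hot input below evaluates to roughly -5.88, in line with the value shown in the example for this class.

```python
import jax
from jax import numpy as jnp
from jax.scipy.special import gammaln


def multinomial_log_prob(labels, logits):
    """Hypothetical helper: multinomial log-probability per batch element."""
    labels = labels.astype(jnp.float32)
    log_probs = jax.nn.log_softmax(logits.astype(jnp.float32), axis=-1)
    n = labels.sum(axis=-1)
    # log(n! / prod(k_i!)); this is zero for one-hot labels.
    log_coef = gammaln(n + 1) - gammaln(labels + 1).sum(axis=-1)
    return log_coef + (labels * log_probs).sum(axis=-1)


logits = jnp.array([[-1, -1, 1, 1, 1, -1, -1]])

# Many-hot labels: four draws spread over four classes.
counts = jnp.array([[0, 0, 0, 1, 1, 1, 1]])
multinomial = multinomial_log_prob(counts, logits)

# One-hot labels: reduces to the categorical log-probability of class 3.
one_hot = jnp.array([[0, 0, 0, 1, 0, 0, 0]])
categorical = multinomial_log_prob(one_hot, logits)
```

Because the coefficient term vanishes for one-hot labels, the same function covers both modes, which is why the two kinds of rows can share a batch.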

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import LogProb
>>>
>>> labels = jnp.array([[ 0,  0,  0,  1,  1,  1,  1]])
>>> logits = jnp.array([[-1, -1,  1,  1,  1, -1, -1]])
>>> metric = LogProb()
>>> metric.update(labels=labels, logits=logits)
LogProb(...)
>>> metric.compute()
Array(-5.879968, dtype=float32)

compute() → Array

Compute and return the mean log probability score.

reset() → Self

Reset the metric state.

update(labels: Array, logits: Array, *, mask: Array | None = None, **_) → Self

Update the metric with a batch of predictions.

Parameters:
  • labels – Ground truth binary labels or multinomial counts with shape (..., num_classes), where ... denotes the batch shape. For binary classification, use labels with shape (..., 1).

  • logits – Predicted logits with shape (..., num_classes), where ... denotes the batch shape. For binary classification, use logits with shape (..., 1).

  • mask – Binary mask indicating which elements to include.
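
For the binary case, the trailing axis of length 1 matters. The sketch below shows one plausible reading of that mode: a Bernoulli log-probability with success probability sigmoid(logit). This is an assumption about the scoring rule, and the helper bernoulli_log_prob is hypothetical.

```python
import jax
from jax import numpy as jnp


def bernoulli_log_prob(labels, logits):
    """Hypothetical helper: Bernoulli log-probability for shape (..., 1) inputs."""
    labels = labels.astype(jnp.float32)
    logits = logits.astype(jnp.float32)
    # log p if the label is 1, log(1 - p) if the label is 0, with p = sigmoid(logit).
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(-logits)
    return labels * log_p + (1 - labels) * log_not_p


# Three binary outcomes, each with a trailing axis of length 1.
labels = jnp.array([[1], [0], [1]])
logits = jnp.array([[2.0], [2.0], [0.0]])
log_probs = bernoulli_log_prob(labels, logits)  # shape (3, 1)
mean_log_prob = log_probs.mean()
```

A flat array of shape (3,) would instead look like a single three-class outcome under the shape rules listed for this class, so reshaping binary data to (..., 1) is the safer habit.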

class flax_metrics.classification.Precision(threshold: float = 0.0)

Precision metric, the fraction of identified positives that are true positives.

See also

This metric is implemented in scikit-learn as sklearn.metrics.precision_score().

Parameters:

threshold – Threshold for identifying items as positives.
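
The effect of the threshold is easiest to see by evaluating precision at two different values on the same data. This is a hand-rolled sketch, not the library's implementation: precision_at is a hypothetical helper, and the >= comparison is assumed.

```python
from jax import numpy as jnp


def precision_at(labels, logits, threshold):
    """Hypothetical helper: precision at a given logit threshold."""
    preds = logits >= threshold  # assumed comparison
    tp = jnp.sum(preds & (labels == 1))
    return tp / jnp.sum(preds)


labels = jnp.array([0, 0, 1, 1, 1])
logits = jnp.array([-2.0, 0.5, 0.2, 1.5, 3.0])

loose = precision_at(labels, logits, threshold=0.0)   # 4 predicted positives, 3 correct
strict = precision_at(labels, logits, threshold=1.0)  # 2 predicted positives, 2 correct
```

Raising the threshold here removes the one false positive, at the cost of no longer flagging one true positive.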

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import Precision
>>>
>>> labels = jnp.array([ 0,  0,  0,  1,  1,  1,  1])
>>> logits = jnp.array([-1, -1,  1,  1,  1, -1, -1])
>>> metric = Precision()
>>> metric.update(labels=labels, logits=logits)
Precision(...)
>>> metric.compute()
Array(0.6666667, dtype=float32)

compute() → Array

Compute and return the precision.

reset() → Self

Reset the metric state.

update(labels: Array, logits: Array, *, mask: Array | None = None, **_) → Self

Update the metric with a batch of predictions.

Parameters:
  • labels – Ground truth binary labels.

  • logits – Predicted logits.

  • mask – Binary mask indicating which elements to include.

class flax_metrics.classification.Recall(threshold: float = 0.0)

Recall metric, the fraction of actual positives that were correctly identified.

See also

This metric is implemented in scikit-learn as sklearn.metrics.recall_score().

Parameters:

threshold – Threshold for identifying items as positives.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import Recall
>>>
>>> labels = jnp.array([ 0,  0,  0,  1,  1,  1,  1])
>>> logits = jnp.array([-1, -1,  1,  1,  1, -1, -1])
>>> metric = Recall()
>>> metric.update(labels=labels, logits=logits)
Recall(...)
>>> metric.compute()
Array(0.5, dtype=float32)

compute() → Array

Compute and return the recall.

reset() → Self

Reset the metric state.

update(labels: Array, logits: Array, *, mask: Array | None = None, **_) → Self

Update the metric with a batch of predictions.

Parameters:
  • labels – Ground truth binary labels.

  • logits – Predicted logits.

  • mask – Binary mask indicating which elements to include.
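
As a sketch of what excluding elements means for recall, the code below recomputes the metric by hand with and without a mask, reusing the data from the example in this section. The helper masked_recall is hypothetical, and both the >= comparison and the interpretation of the mask (masked-out elements are dropped from numerator and denominator) are assumptions.

```python
from jax import numpy as jnp


def masked_recall(labels, logits, mask, threshold=0.0):
    """Hypothetical helper: recall over the elements selected by `mask`."""
    mask = mask.astype(bool)
    preds = logits >= threshold  # assumed comparison
    actual = (labels == 1) & mask
    tp = jnp.sum(preds & actual)
    return tp / jnp.sum(actual)


labels = jnp.array([0, 0, 0, 1, 1, 1, 1])
logits = jnp.array([-1, -1, 1, 1, 1, -1, -1])

everything = masked_recall(labels, logits, jnp.ones(7))
# Exclude the last two elements, e.g. because they are padding.
without_tail = masked_recall(labels, logits, jnp.array([1, 1, 1, 1, 1, 0, 0]))
```

Masking out the two missed positives leaves only the two that were found, so recall rises from 0.5 to 1.0.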