Classification Metrics
Metrics for evaluating classifiers, including accuracy, precision, recall, F1 score, and log probability. These metrics operate on logits and binary or multinomial labels, applying a threshold (or, for multi-class accuracy, an argmax) to convert logits to point estimates where required.
- class flax_metrics.classification.Accuracy(threshold: float | None = None)
Accuracy metric, the fraction of correct predictions.
For multi-class classification, the logits are argmax-ed before being compared to the labels. For binary classification, pass a threshold to determine positive predictions.
See also
This metric is implemented in scikit-learn as sklearn.metrics.accuracy_score().
- Parameters:
threshold – For binary classification, logits >= threshold are considered positive. If None (default), multi-class classification is assumed.
Example
Multi-class classification:
>>> from jax import numpy as jnp
>>> from flax_metrics import Accuracy
>>>
>>> logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
>>> labels = jnp.array([1, 1, 1])
>>> metric = Accuracy()
>>> metric.update(logits=logits, labels=labels)
Accuracy(...)
>>> metric.compute()
Array(0.666..., dtype=float32)
Binary classification:
>>> logits = jnp.array([0.6, 0.4, 0.8, 0.3])
>>> labels = jnp.array([1, 1, 1, 0])
>>> metric = Accuracy(threshold=0.5)
>>> metric.update(logits=logits, labels=labels)
Accuracy(...)
>>> metric.compute()
Array(0.75, dtype=float32)
- reset() → Self
Reset the metric state.
- update(logits: Array, labels: Array, *, mask: Array | None = None, **_) → Self
Update the metric with a batch of predictions.
- Parameters:
logits – Predicted logits. For multi-class, shape (..., num_classes). For binary, shape (...,).
labels – Ground truth integer labels with shape (...,).
mask – Binary mask indicating which elements to include.
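To make the computation concrete, here is a minimal pure-JAX sketch of what Accuracy evaluates; it is not the library's implementation. The helper accuracy_sketch is hypothetical, and it assumes that mask has the same shape as labels and that masked-out elements are simply dropped from the mean. It reuses the inputs of the doctests above.

```python
from jax import numpy as jnp


def accuracy_sketch(logits, labels, threshold=None, mask=None):
    """Hypothetical re-implementation of the accuracy computation.

    Assumes `mask` has the same shape as `labels` and that masked-out
    elements are dropped from the mean.
    """
    if threshold is None:
        # Multi-class: the predicted class is the argmax over the last axis.
        predictions = jnp.argmax(logits, axis=-1)
    else:
        # Binary: logits >= threshold are positive (1), the rest negative (0).
        predictions = (logits >= threshold).astype(labels.dtype)
    correct = (predictions == labels).astype(jnp.float32)
    if mask is None:
        mask = jnp.ones_like(correct)
    mask = mask.astype(jnp.float32)
    # Fraction of correct predictions among the included elements.
    return jnp.sum(correct * mask) / jnp.sum(mask)


multi = accuracy_sketch(
    jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]), jnp.array([1, 1, 1])
)  # 2 of 3 correct
binary = accuracy_sketch(
    jnp.array([0.6, 0.4, 0.8, 0.3]), jnp.array([1, 1, 1, 0]), threshold=0.5
)  # 3 of 4 correct
# Masking out the one wrong binary prediction (index 1) leaves 3 of 3 correct.
masked = accuracy_sketch(
    jnp.array([0.6, 0.4, 0.8, 0.3]),
    jnp.array([1, 1, 1, 0]),
    threshold=0.5,
    mask=jnp.array([1, 0, 1, 1]),
)
```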
- class flax_metrics.classification.F1Score(threshold: float = 0.0)
F1 score, the harmonic mean of precision and recall.
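In terms of true positives (TP), false positives (FP), and false negatives (FN), precision P, recall R, and their harmonic mean follow the standard definitions below; the last form is obtained by substituting P and R into the harmonic mean.

```latex
P = \frac{TP}{TP + FP}, \qquad
R = \frac{TP}{TP + FN}, \qquad
F_1 = \frac{2\,P\,R}{P + R} = \frac{2\,TP}{2\,TP + FP + FN}
```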
See also
This metric is implemented in scikit-learn as sklearn.metrics.f1_score().
- Parameters:
threshold – Threshold for identifying items as positives.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import F1Score
>>>
>>> labels = jnp.array([ 0, 0, 0, 1, 1, 1, 1])
>>> logits = jnp.array([-1, -1, 1, 1, 1, -1, -1])
>>> metric = F1Score()
>>> metric.update(labels=labels, logits=logits)
F1Score(...)
>>> metric.compute()
Array(0.5714286, dtype=float32)
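The doctest value can be checked by hand from confusion counts. The sketch below uses a hypothetical helper, confusion_counts, and assumes that logits strictly greater than the threshold are predicted positive; the docs do not say whether the comparison is strict, but with these ±1 logits and a threshold of 0.0 it makes no difference.

```python
from jax import numpy as jnp


def confusion_counts(labels, logits, threshold=0.0):
    """Hypothetical helper: count TP, FP, and FN for binary labels.

    Assumes logits strictly above `threshold` are predicted positive.
    """
    predicted = logits > threshold
    actual = labels == 1
    tp = jnp.sum(predicted & actual)   # predicted positive, actually positive
    fp = jnp.sum(predicted & ~actual)  # predicted positive, actually negative
    fn = jnp.sum(~predicted & actual)  # predicted negative, actually positive
    return tp, fp, fn


labels = jnp.array([0, 0, 0, 1, 1, 1, 1])
logits = jnp.array([-1, -1, 1, 1, 1, -1, -1])
tp, fp, fn = confusion_counts(labels, logits)
f1 = 2 * tp / (2 * tp + fp + fn)
```

With two true positives, one false positive, and two false negatives, F1 = 4 / (4 + 1 + 2) = 4/7 ≈ 0.5714, matching the doctest.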
- reset() → Self
Reset the metric state in-place.
- class flax_metrics.classification.LogProb
Log probability score, the mean log-likelihood of the observed outcomes.
The metric supports three modes:
- Binary classification if the logits and labels have shape (..., 1).
- Categorical classification if the inputs have shape (..., num_classes) and the labels are one-hot encoded, i.e., labels.sum(axis=-1) == 1.
- Multinomial outcomes if the inputs have shape (..., num_classes) and the labels are many-hot encoded (more generally, multinomial counts), i.e., labels.sum(axis=-1) > 1.
Categorical and multinomial outcomes may be mixed within the same batch because multinomial outcomes with one sample are equivalent to categorical outcomes.
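The equivalence between categorical and single-trial multinomial outcomes follows from the multinomial log-pmf, log(n! / ∏ x_i!) + Σ x_i log p_i with p = softmax(logits): with one trial the coefficient term is log(1) = 0, leaving only the log-probability of the labelled class. The sketch below assumes that this is the per-row quantity LogProb averages over the batch; multinomial_log_prob is a hypothetical helper, not part of flax_metrics. On the inputs of the LogProb example it gives about -5.88.

```python
from jax import numpy as jnp
from jax.nn import log_softmax
from jax.scipy.special import gammaln


def multinomial_log_prob(labels, logits):
    """Hypothetical helper: log-pmf of counts `labels` under softmax(logits)."""
    labels = labels.astype(jnp.float32)
    n = labels.sum(axis=-1)
    # Log of the multinomial coefficient n! / prod(x_i!), via log-gamma.
    log_coef = gammaln(n + 1) - gammaln(labels + 1).sum(axis=-1)
    return log_coef + (labels * log_softmax(logits, axis=-1)).sum(axis=-1)


logits = jnp.array([[-1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0]])

# Many-hot labels (four trials), as in the LogProb example.
counts = jnp.array([[0, 0, 0, 1, 1, 1, 1]])
multinomial = multinomial_log_prob(counts, logits)

# One-hot labels (one trial): the coefficient vanishes, so the result is
# the categorical log-probability of the labelled class (index 3).
one_hot = jnp.array([[0, 0, 0, 1, 0, 0, 0]])
categorical = multinomial_log_prob(one_hot, logits)
reference = log_softmax(logits, axis=-1)[0, 3]
```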
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import LogProb
>>>
>>> labels = jnp.array([[ 0, 0, 0, 1, 1, 1, 1]])
>>> logits = jnp.array([[-1, -1, 1, 1, 1, -1, -1]])
>>> metric = LogProb()
>>> metric.update(labels=labels, logits=logits)
LogProb(...)
>>> metric.compute()
Array(-5.879968, dtype=float32)
- reset() → Self
Reset the metric state.
- update(labels: Array, logits: Array, *, mask: Array | None = None, **_) → Self
Update the metric with a batch of predictions.
- Parameters:
labels – Ground truth binary labels or multinomial counts with shape (..., num_classes), where ... denotes the batch shape. For binary classification, use labels with shape (..., 1).
logits – Predicted logits with shape (..., num_classes), where ... denotes the batch shape. For binary classification, use logits with shape (..., 1).
mask – Binary mask indicating which elements to include.
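For the binary mode with inputs of shape (..., 1), a natural reading is the Bernoulli log-likelihood with p = sigmoid(logit). The sketch below assumes exactly that, together with a mask over the batch shape that drops masked-out items from the mean; neither detail is confirmed by the parameter descriptions above, and bernoulli_log_prob is a hypothetical name.

```python
from jax import numpy as jnp
from jax.nn import log_sigmoid


def bernoulli_log_prob(labels, logits, mask=None):
    """Hypothetical binary-mode computation for inputs of shape (..., 1).

    Assumes a Bernoulli log-likelihood with p = sigmoid(logit), averaged
    over the unmasked elements of the batch shape.
    """
    labels = labels.astype(jnp.float32)
    # log p for positives; log(1 - p) = log_sigmoid(-logit) for negatives.
    per_item = labels * log_sigmoid(logits) + (1 - labels) * log_sigmoid(-logits)
    per_item = per_item[..., 0]  # drop the trailing singleton axis
    if mask is None:
        mask = jnp.ones_like(per_item)
    mask = mask.astype(jnp.float32)
    return jnp.sum(per_item * mask) / jnp.sum(mask)


labels = jnp.array([[1], [0], [1]])          # shape (3, 1)
logits = jnp.array([[2.0], [-1.0], [0.0]])   # shape (3, 1)
score = bernoulli_log_prob(labels, logits)
# Ignore the third item, e.g. padding in a fixed-size batch.
masked_score = bernoulli_log_prob(labels, logits, mask=jnp.array([1, 1, 0]))
```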
- class flax_metrics.classification.Precision(threshold: float = 0.0)
Precision metric, the fraction of identified positives that are true positives.
See also
This metric is implemented in scikit-learn as sklearn.metrics.precision_score().
- Parameters:
threshold – Threshold for identifying items as positives.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import Precision
>>>
>>> labels = jnp.array([ 0, 0, 0, 1, 1, 1, 1])
>>> logits = jnp.array([-1, -1, 1, 1, 1, -1, -1])
>>> metric = Precision()
>>> metric.update(labels=labels, logits=logits)
Precision(...)
>>> metric.compute()
Array(0.6666667, dtype=float32)
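Because update is called once per batch, a ratio metric such as precision is best accumulated as counts rather than as an average of per-batch scores. The class below, StreamingPrecision, is a hypothetical sketch of that idea and not necessarily how flax_metrics is implemented; it also assumes that logits strictly above the threshold count as positive.

```python
from jax import numpy as jnp


class StreamingPrecision:
    """Hypothetical count-based precision accumulator (not the library's code)."""

    def __init__(self, threshold=0.0):
        self.threshold = threshold
        self.true_positives = 0
        self.predicted_positives = 0

    def update(self, labels, logits):
        # Assumption: logits strictly above the threshold count as positive.
        predicted = logits > self.threshold
        self.true_positives += int(jnp.sum(predicted & (labels == 1)))
        self.predicted_positives += int(jnp.sum(predicted))
        return self

    def compute(self):
        # No guard against zero predicted positives; this is only a sketch.
        return self.true_positives / self.predicted_positives


labels = jnp.array([0, 0, 0, 1, 1, 1, 1])
logits = jnp.array([-1, -1, 1, 1, 1, -1, -1])

# Split the doctest data into two batches and accumulate counts.
metric = StreamingPrecision()
metric.update(labels[:3], logits[:3]).update(labels[3:], logits[3:])
pooled = metric.compute()

# For comparison: averaging the per-batch precisions instead.
first = StreamingPrecision().update(labels[:3], logits[:3]).compute()
second = StreamingPrecision().update(labels[3:], logits[3:]).compute()
averaged = (first + second) / 2
```

Accumulating counts gives 2/3, the same as a single update on all seven items, whereas averaging the per-batch precisions (0 and 1) would report 0.5.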
- reset() → Self
Reset the metric state.
- class flax_metrics.classification.Recall(threshold: float = 0.0)
Recall metric, the fraction of actual positives that were correctly identified.
See also
This metric is implemented in scikit-learn as sklearn.metrics.recall_score().
- Parameters:
threshold – Threshold for identifying items as positives.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import Recall
>>>
>>> labels = jnp.array([ 0, 0, 0, 1, 1, 1, 1])
>>> logits = jnp.array([-1, -1, 1, 1, 1, -1, -1])
>>> metric = Recall()
>>> metric.update(labels=labels, logits=logits)
Recall(...)
>>> metric.compute()
Array(0.5, dtype=float32)
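The threshold parameter trades precision against recall. The sketch below uses a hypothetical helper, precision_recall, which assumes that logits strictly above the threshold are positive, and evaluates both metrics on the doctest data at the default threshold and at a threshold below every logit.

```python
from jax import numpy as jnp


def precision_recall(labels, logits, threshold):
    """Hypothetical helper: precision and recall at a given threshold.

    Assumes logits strictly greater than `threshold` are predicted positive.
    """
    predicted = logits > threshold
    actual = labels == 1
    tp = jnp.sum(predicted & actual)
    precision = tp / jnp.sum(predicted)  # TP / (TP + FP)
    recall = tp / jnp.sum(actual)        # TP / (TP + FN)
    return float(precision), float(recall)


labels = jnp.array([0, 0, 0, 1, 1, 1, 1])
logits = jnp.array([-1, -1, 1, 1, 1, -1, -1])

# Default threshold 0.0: recall is 2/4, as in the doctest.
at_zero = precision_recall(labels, logits, 0.0)
# A threshold below every logit flags all items as positive:
# recall rises to 1.0 while precision drops to 4/7.
at_minus_two = precision_recall(labels, logits, -2.0)
```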
- reset() → Self
Reset the metric state.