Ranking Metrics
Metrics for evaluating ranked retrieval results, where items are ranked by precomputed scores. These metrics compare the ranking against relevance labels to measure retrieval quality.
- class flax_metrics.ranking.MeanAveragePrecision(k: int | None = None)
Mean Average Precision.
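The mean of average precision (AP) scores across queries. For a single query, AP is the sum over ranks i of precision@i * rel(i), divided by the total number of relevant items, where rel(i) is 1 if the item at rank i is relevant and 0 otherwise.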
See also
This metric is implemented in ir-measures as AP.
- Parameters:
k – Number of top items to consider. If None, considers all items.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import MeanAveragePrecision
>>>
>>> scores = jnp.array([0.4, 0.3, 0.2, 0.1])
>>> relevance = jnp.array([ 1, 1, 0, 1])
>>> metric = MeanAveragePrecision()
>>> metric.update(labels=relevance, scores=scores)
MeanAveragePrecision(...)
>>> metric.compute()  # (1/1 + 2/2 + 3/4) / 3
Array(0.9166667, dtype=float32)
- reset() → Self
Reset the metric state in-place.
- update(labels: Array, scores: Array, *, mask: Array | None = None, **_) → Self
Update the mean average precision with a batch of scored items.
- Parameters:
labels – Relevance labels, shape (..., num_items).
scores – Scores for each item, same shape as labels.
mask – Binary mask indicating which queries to include.
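The AP definition above can be written out in a few lines of JAX. This is a minimal sketch under stated assumptions, not the flax_metrics implementation: average_precision and mean_average_precision are hypothetical helpers, a label greater than 0 is assumed to mean relevant, tied scores keep their original item order, the denominator counts all relevant items even when k is set (following the definition above), queries with no relevant items are assumed to score 0, and mask is treated as per-query weights in the final mean.

```python
import jax.numpy as jnp


def average_precision(labels, scores, k=None):
    """Per-query AP over the last axis (hypothetical helper, not the library API)."""
    # Rank items by descending score; assumption: label > 0 means relevant.
    order = jnp.argsort(-scores, axis=-1)
    rel = (jnp.take_along_axis(labels, order, axis=-1) > 0).astype(jnp.float32)
    ranks = jnp.arange(1, rel.shape[-1] + 1)
    # precision@i for every 1-based rank i.
    precision_at_rank = jnp.cumsum(rel, axis=-1) / ranks
    # Only ranks within the cutoff contribute; k=None keeps every rank.
    in_cutoff = jnp.ones_like(rel) if k is None else (ranks <= k).astype(jnp.float32)
    numerator = jnp.sum(precision_at_rank * rel * in_cutoff, axis=-1)
    # Divide by all relevant items; queries without any are assumed to score 0.
    total_relevant = jnp.sum(rel, axis=-1)
    return jnp.where(total_relevant > 0, numerator / jnp.maximum(total_relevant, 1.0), 0.0)


def mean_average_precision(labels, scores, mask=None, k=None):
    """Mean AP across queries, optionally restricted to queries where mask is 1."""
    ap = average_precision(labels, scores, k)
    if mask is None:
        return jnp.mean(ap)
    mask = mask.astype(jnp.float32)
    return jnp.sum(ap * mask) / jnp.sum(mask)


scores = jnp.array([0.4, 0.3, 0.2, 0.1])
labels = jnp.array([1, 1, 0, 1])
print(mean_average_precision(labels, scores))       # (1/1 + 2/2 + 3/4) / 3, about 0.9167
print(mean_average_precision(labels, scores, k=2))  # (1/1 + 2/2) / 3, about 0.6667
```

With a batch of shape (num_queries, num_items), average_precision returns one value per query, and a mask such as jnp.array([1, 0]) drops the second query from the mean.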
- class flax_metrics.ranking.MeanReciprocalRank(k: int | None = None)
Mean Reciprocal Rank.
The mean, across queries, of the reciprocal rank 1/r, where r is the 1-based rank of the first relevant item.
See also
This metric is implemented in ir-measures as RR.
- Parameters:
k – Number of top items to consider. If None, considers all items.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import MeanReciprocalRank
>>>
>>> scores = jnp.array([0.1, 0.4, 0.3, 0.2])
>>> relevance = jnp.array([ 1, 0, 0, 1])
>>> metric = MeanReciprocalRank()
>>> metric.update(labels=relevance, scores=scores)
MeanReciprocalRank(...)
>>> metric.compute()  # first relevant at rank 3
Array(0.33333334, dtype=float32)
- reset() → Self
Reset the metric state in-place.
- update(labels: Array, scores: Array, *, mask: Array | None = None, **_) → Self
Update the mean reciprocal rank with a batch of scored items.
- Parameters:
labels – Relevance labels, shape (..., num_items).
scores – Scores for each item, same shape as labels.
mask – Binary mask indicating which queries to include.
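As a rough illustration of the reciprocal-rank definition, the sketch below finds the first relevant position after sorting by score. It is hypothetical and not the library's code: reciprocal_rank is an invented helper, a label greater than 0 is assumed to mean relevant, tied scores keep their original order, and a query with no relevant item within the cutoff is assumed to contribute 0.

```python
import jax.numpy as jnp


def reciprocal_rank(labels, scores, k=None):
    """Per-query 1/r for the first relevant item (hypothetical helper)."""
    # Relevance flags in ranked order (highest score first).
    order = jnp.argsort(-scores, axis=-1)
    rel = jnp.take_along_axis(labels, order, axis=-1) > 0
    ranks = jnp.arange(1, rel.shape[-1] + 1)
    if k is not None:
        # Ignore relevant items that fall below the cutoff.
        rel = rel & (ranks <= k)
    # argmax returns the first position holding the maximum, i.e. the first relevant item.
    first = jnp.argmax(rel.astype(jnp.int32), axis=-1)
    found = jnp.any(rel, axis=-1)
    return jnp.where(found, 1.0 / (first + 1), 0.0)


scores = jnp.array([0.1, 0.4, 0.3, 0.2])
labels = jnp.array([1, 0, 0, 1])
print(jnp.mean(reciprocal_rank(labels, scores)))       # first relevant at rank 3, so 1/3
print(jnp.mean(reciprocal_rank(labels, scores, k=2)))  # nothing relevant in the top 2, so 0
```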
- class flax_metrics.ranking.NDCG(k: int | None = None)
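Normalized Discounted Cumulative Gain: the discounted cumulative gain (DCG) of the ranking divided by the DCG of the ideal ranking (IDCG), so a perfect ranking scores 1.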
See also
This metric is implemented in ir-measures as nDCG.
- Parameters:
k – Number of top items to consider. If None, considers all items.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import NDCG
>>>
>>> scores = jnp.array([0.1, 0.4, 0.3, 0.2])
>>> relevance = jnp.array([ 3, 2, 1, 0])
>>> metric = NDCG(k=3)
>>> metric.update(labels=relevance, scores=scores)
NDCG(...)
>>> metric.compute()  # DCG / IDCG
Array(0.5525..., dtype=float32)
- reset() → Self
Reset the metric state in-place.
- update(labels: Array, scores: Array, *, mask: Array | None = None, **_) → Self
Update the NDCG with a batch of scored items.
- Parameters:
labels – Relevance labels (can be graded), shape (..., num_items).
scores – Scores for each item, same shape as labels.
mask – Binary mask indicating which queries to include.
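A minimal sketch of the DCG / IDCG ratio follows. It is a hypothetical illustration, not the library's code: dcg and ndcg are invented helpers, and the sketch assumes a linear gain (the label itself) with a 1/log2(i + 1) discount at 1-based rank i. Under those assumptions it reproduces the 0.5525 in the example above: DCG@3 = 2/log2(2) + 1/log2(3) ≈ 2.631 and IDCG@3 = 3 + 2/log2(3) + 1/log2(4) ≈ 4.762. Exponential gains (2^rel - 1) are another common convention and would give a different value.

```python
import jax.numpy as jnp


def dcg(gains_in_rank_order, k=None):
    """Linear-gain DCG: sum over ranks i of gain_i / log2(i + 1) (hypothetical helper)."""
    ranks = jnp.arange(1, gains_in_rank_order.shape[-1] + 1)
    discounts = 1.0 / jnp.log2(ranks + 1.0)
    if k is not None:
        # Zero out the discount for ranks past the cutoff.
        discounts = discounts * (ranks <= k)
    return jnp.sum(gains_in_rank_order * discounts, axis=-1)


def ndcg(labels, scores, k=None):
    """Per-query NDCG; queries whose ideal DCG is 0 are assumed to score 0."""
    labels = labels.astype(jnp.float32)
    order = jnp.argsort(-scores, axis=-1)
    actual = dcg(jnp.take_along_axis(labels, order, axis=-1), k)
    # Ideal ranking: labels sorted from most to least relevant.
    ideal = dcg(-jnp.sort(-labels, axis=-1), k)
    return jnp.where(ideal > 0, actual / jnp.where(ideal > 0, ideal, 1.0), 0.0)


scores = jnp.array([0.1, 0.4, 0.3, 0.2])
labels = jnp.array([3, 2, 1, 0])
print(ndcg(labels, scores, k=3))  # about 0.5525
```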
- class flax_metrics.ranking.PrecisionAtK(k: int)
Precision@K, the fraction of top-k items that are relevant.
See also
This metric is implemented in ir-measures as P.
- Parameters:
k – Number of top items to consider.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import PrecisionAtK
>>>
>>> scores = jnp.array([0.1, 0.4, 0.3, 0.2])
>>> relevance = jnp.array([ 0, 1, 1, 0])
>>> metric = PrecisionAtK(k=2)
>>> metric.update(labels=relevance, scores=scores)
PrecisionAtK(...)
>>> metric.compute()  # top-2 are indices 1 and 2, both relevant
Array(1., dtype=float32)
- reset() → Self
Reset the metric state in-place.
- update(labels: Array, scores: Array, *, mask: Array | None = None, **_) → Self
Update the precision@k with a batch of scored items.
- Parameters:
labels – Relevance labels, shape (..., num_items).
scores – Scores for each item, same shape as labels.
mask – Binary mask indicating which queries to include.
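Precision@K reduces to counting relevant items among the k highest scores. The sketch below is hypothetical (precision_at_k is an invented helper, not the library API) and assumes that a label greater than 0 means relevant and that the count is always divided by k.

```python
import jax.numpy as jnp


def precision_at_k(labels, scores, k):
    """Per-query fraction of the top-k items that are relevant (hypothetical helper)."""
    # Indices of the k highest-scoring items, best first.
    top_k = jnp.argsort(-scores, axis=-1)[..., :k]
    relevant = jnp.take_along_axis(labels, top_k, axis=-1) > 0
    return jnp.sum(relevant, axis=-1) / k


scores = jnp.array([0.1, 0.4, 0.3, 0.2])
labels = jnp.array([0, 1, 1, 0])
print(precision_at_k(labels, scores, k=2))  # indices 1 and 2 are both relevant: 2/2
print(precision_at_k(labels, scores, k=3))  # index 3 is not relevant: 2/3
```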
- class flax_metrics.ranking.RecallAtK(k: int)
Recall@K, the fraction of relevant items that appear in the top-k ranked results.
Computes mean recall over all queries (macro-average).
See also
This metric is implemented in ir-measures as R.
- Parameters:
k – Number of top items to consider.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import RecallAtK
>>>
>>> scores = jnp.array([0.1, 0.4, 0.3, 0.2])
>>> relevance = jnp.array([ 1, 1, 1, 0])
>>> metric = RecallAtK(k=2)
>>> metric.update(labels=relevance, scores=scores)
RecallAtK(...)
>>> metric.compute()  # 2 of 3 relevant items in top-2
Array(0.6666667, dtype=float32)
- reset() → Self
Reset the metric state in-place.
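To make the macro-average concrete, the sketch below computes recall@k per query and then averages across queries, so each query counts equally. It is hypothetical (recall_at_k is an invented helper, not the library API) and assumes that a label greater than 0 means relevant, that queries with no relevant items score 0, and that mask acts as per-query weights.

```python
import jax.numpy as jnp


def recall_at_k(labels, scores, k, mask=None):
    """Macro-averaged recall@k across queries (hypothetical helper)."""
    relevant = labels > 0
    # Indices of the k highest-scoring items, best first.
    top_k = jnp.argsort(-scores, axis=-1)[..., :k]
    hits = jnp.sum(jnp.take_along_axis(relevant, top_k, axis=-1), axis=-1)
    total = jnp.sum(relevant, axis=-1)
    # Per-query recall; queries without relevant items are assumed to score 0.
    per_query = jnp.where(total > 0, hits / jnp.maximum(total, 1), 0.0)
    # Macro-average: every query has equal weight, however many relevant items it has.
    if mask is None:
        return jnp.mean(per_query)
    mask = mask.astype(jnp.float32)
    return jnp.sum(per_query * mask) / jnp.sum(mask)


scores = jnp.array([[0.1, 0.4, 0.3, 0.2], [0.1, 0.4, 0.3, 0.2]])
labels = jnp.array([[1, 1, 1, 0], [1, 0, 0, 0]])
print(recall_at_k(labels, scores, k=2))  # (2/3 + 0/1) / 2, about 0.3333
```

A micro-average would instead pool the counts, (2 + 0) / (3 + 1) = 0.5, which shows why the choice of averaging matters when queries have different numbers of relevant items.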