Ranking Metrics

Metrics for evaluating ranked retrieval results, where items are ranked by precomputed scores. These metrics compare the ranking against relevance labels to measure retrieval quality.
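Every metric on this page starts with the same ranking step: order the items by descending score, then read off their labels in rank order. The plain-Python sketch below illustrates that step; it is not the library's JAX implementation. The helper name `rank_labels` is hypothetical. Breaking ties by original index, which falls out of Python's stable sort, is an assumption about tie handling.

```python
def rank_labels(labels, scores):
    """Return `labels` reordered by descending score (rank 1 first)."""
    # Python's sort is stable, so tied scores keep their original order;
    # this tie-breaking rule is an assumption, not the library's behavior.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    return [labels[i] for i in order]


scores = [0.1, 0.4, 0.3, 0.2]
labels = [1, 0, 0, 1]
print(rank_labels(labels, scores))  # [0, 0, 1, 1]
```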

class flax_metrics.ranking.MeanAveragePrecision(k: int | None = None)

Mean Average Precision.

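The mean of average precision (AP) scores across queries. For a single query, AP is the sum over rank positions i of precision@i · rel(i), divided by the total number of relevant items. Here rel(i) is 1 if the item at rank i is relevant and 0 otherwise.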
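The definition above can be checked with a plain-Python reference sketch. It computes AP for a single query with binary labels and then takes the mean over queries. The names `average_precision` and `mean_average_precision` are hypothetical. Two behaviors are assumptions, not documented facts:

- A query with no relevant items scores 0.
- When `k` is set, only the top-k ranks contribute to the sum, but the denominator is still the total number of relevant items.

```python
def average_precision(labels, scores, k=None):
    """AP for one query with binary relevance labels (illustrative sketch)."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranked = [labels[i] for i in order]
    total_relevant = sum(1 for rel in labels if rel > 0)
    if total_relevant == 0:
        return 0.0  # assumption: queries without relevant items score 0
    hits = 0
    precision_sum = 0.0
    # ranked[:None] keeps every item, so k=None considers the full ranking.
    for rank, rel in enumerate(ranked[:k], start=1):
        if rel > 0:
            hits += 1
            precision_sum += hits / rank  # precision@rank, added only where rel = 1
    # Assumption: the denominator is all relevant items, even with a cutoff k.
    return precision_sum / total_relevant


def mean_average_precision(queries, k=None):
    """Mean of per-query AP over a list of (labels, scores) pairs."""
    aps = [average_precision(labels, scores, k) for labels, scores in queries]
    return sum(aps) / len(aps)


# Same data as the doctest example: relevant items at ranks 1, 2 and 4.
print(average_precision([1, 1, 0, 1], [0.4, 0.3, 0.2, 0.1]))  # ≈ 0.9167
```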

See also

This metric is implemented in ir-measures as AP.

Parameters:

k – Number of top items to consider. If None, considers all items.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import MeanAveragePrecision
>>>
>>> scores    = jnp.array([0.4, 0.3, 0.2, 0.1])
>>> relevance = jnp.array([  1,   1,   0,   1])
>>> metric = MeanAveragePrecision()
>>> metric.update(labels=relevance, scores=scores)
MeanAveragePrecision(...)
>>> metric.compute()  # (1/1 + 2/2 + 3/4) / 3
Array(0.9166667, dtype=float32)

compute() → Array

Compute and return the mean average precision.

reset() → Self

Reset the metric state in-place.

update(labels: Array, scores: Array, *, mask: Array | None = None, **_) → Self

Update the mean average precision with a batch of scored items.

Parameters:
  • labels – Relevance labels, shape (..., num_items).

  • scores – Scores for each item, same shape as labels.

  • mask – Binary mask indicating which queries to include.
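The update/compute/reset methods suggest that each metric accumulates state batch by batch and reduces it only when compute() is called. The plain-Python sketch below shows one way a per-query mean could honor the per-query mask. The class `MaskedMeanAccumulator` and its `total`/`count` state are hypothetical, not the library's internal state.

```python
class MaskedMeanAccumulator:
    """Hypothetical running mean of per-query metric values across batches."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the running state in place and return self, mirroring reset().
        self.total = 0.0
        self.count = 0.0
        return self

    def update(self, per_query_values, mask=None):
        # Assumption: one 0/1 mask entry per query; None includes every query.
        if mask is None:
            mask = [1] * len(per_query_values)
        for value, keep in zip(per_query_values, mask):
            self.total += value * keep
            self.count += keep
        return self

    def compute(self):
        # Assumption: an empty accumulator reports 0 rather than dividing by zero.
        return self.total / self.count if self.count else 0.0


acc = MaskedMeanAccumulator()
acc.update([0.5, 1.0, 0.0], mask=[1, 1, 0])  # third query is masked out
acc.update([0.25])
print(acc.compute())  # ≈ 0.5833, i.e. (0.5 + 1.0 + 0.25) / 3
```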

class flax_metrics.ranking.MeanReciprocalRank(k: int | None = None)

Mean Reciprocal Rank.

The mean, across queries, of the reciprocal rank (1 / rank) of the first relevant item in each query's ranking.
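Reciprocal rank can be sketched in plain Python: scan the ranked labels and return 1 / rank at the first relevant item. The function names are hypothetical. Scoring 0 for a query with no relevant item within the cutoff `k` is an assumption.

```python
def reciprocal_rank(labels, scores, k=None):
    """1 / rank of the first relevant item for one query (illustrative sketch)."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranked = [labels[i] for i in order]
    for rank, rel in enumerate(ranked[:k], start=1):
        if rel > 0:
            return 1.0 / rank
    return 0.0  # assumption: no relevant item within the cutoff scores 0


def mean_reciprocal_rank(queries, k=None):
    """Mean reciprocal rank over a list of (labels, scores) pairs."""
    rrs = [reciprocal_rank(labels, scores, k) for labels, scores in queries]
    return sum(rrs) / len(rrs)


# Doctest data: ranked labels are [0, 0, 1, 1], so the first hit is at rank 3.
print(reciprocal_rank([1, 0, 0, 1], [0.1, 0.4, 0.3, 0.2]))  # ≈ 0.3333
```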

See also

This metric is implemented in ir-measures as RR.

Parameters:

k – Number of top items to consider. If None, considers all items.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import MeanReciprocalRank
>>>
>>> scores    = jnp.array([0.1, 0.4, 0.3, 0.2])
>>> relevance = jnp.array([  1,   0,   0,   1])
>>> metric = MeanReciprocalRank()
>>> metric.update(labels=relevance, scores=scores)
MeanReciprocalRank(...)
>>> metric.compute()  # first relevant at rank 3
Array(0.33333334, dtype=float32)

compute() → Array

Compute and return the mean reciprocal rank.

reset() → Self

Reset the metric state in-place.

update(labels: Array, scores: Array, *, mask: Array | None = None, **_) → Self

Update the mean reciprocal rank with a batch of scored items.

Parameters:
  • labels – Relevance labels, shape (..., num_items).

  • scores – Scores for each item, same shape as labels.

  • mask – Binary mask indicating which queries to include.

class flax_metrics.ranking.NDCG(k: int | None = None)

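Normalized Discounted Cumulative Gain.

The discounted cumulative gain (DCG) of the predicted ranking divided by that of the ideal ranking (IDCG), so a perfect ranking scores 1.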
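NDCG can be sketched in plain Python as DCG / IDCG. This sketch assumes two choices:

- Linear gain: each label value is used directly as its gain.
- A 1 / log2(rank + 1) discount.

With these choices the sketch reproduces the 0.5525 value in the doctest. Exponential gain (2^rel − 1) would give a different number, so treat the gain choice as an assumption. Returning 0 when IDCG is 0 is also an assumption. The function names are hypothetical.

```python
import math


def dcg(gains):
    """Sum of gain / log2(rank + 1), so rank 1 is undiscounted."""
    return sum(g / math.log2(rank + 1) for rank, g in enumerate(gains, start=1))


def ndcg(labels, scores, k=None):
    """NDCG for one query with graded labels used directly as linear gains."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranked_gains = [labels[i] for i in order][:k]
    # The ideal ranking sorts the same labels from most to least relevant.
    ideal_gains = sorted(labels, reverse=True)[:k]
    idcg = dcg(ideal_gains)
    return dcg(ranked_gains) / idcg if idcg > 0 else 0.0  # assumption: IDCG = 0 gives 0


# Doctest data with k=3: DCG uses gains [2, 1, 0], IDCG uses [3, 2, 1].
print(ndcg([3, 2, 1, 0], [0.1, 0.4, 0.3, 0.2], k=3))  # ≈ 0.5525
```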

See also

This metric is implemented in ir-measures as nDCG.

Parameters:

k – Number of top items to consider. If None, considers all items.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import NDCG
>>>
>>> scores    = jnp.array([0.1, 0.4, 0.3, 0.2])
>>> relevance = jnp.array([  3,   2,   1,   0])
>>> metric = NDCG(k=3)
>>> metric.update(labels=relevance, scores=scores)
NDCG(...)
>>> metric.compute()  # DCG / IDCG
Array(0.5525..., dtype=float32)

compute() → Array

Compute and return the NDCG.

reset() → Self

Reset the metric state in-place.

update(labels: Array, scores: Array, *, mask: Array | None = None, **_) → Self

Update the NDCG with a batch of scored items.

Parameters:
  • labels – Relevance labels (can be graded), shape (..., num_items).

  • scores – Scores for each item, same shape as labels.

  • mask – Binary mask indicating which queries to include.

class flax_metrics.ranking.PrecisionAtK(k: int)

Precision@K, the fraction of top-k items that are relevant.
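Precision@K reduces to counting relevant items among the top k. The plain-Python sketch below uses a hypothetical `precision_at_k`. Dividing by k even when a query has fewer than k items is an assumption about the library's convention.

```python
def precision_at_k(labels, scores, k):
    """Fraction of the top-k ranked items that are relevant (illustrative sketch)."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    hits = sum(1 for i in order[:k] if labels[i] > 0)
    # Assumption: divide by k even when a query has fewer than k items.
    return hits / k


# Doctest data: the top-2 items (indices 1 and 2) are both relevant.
print(precision_at_k([0, 1, 1, 0], [0.1, 0.4, 0.3, 0.2], k=2))  # 1.0
```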

See also

This metric is implemented in ir-measures as P.

Parameters:

k – Number of top items to consider.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import PrecisionAtK
>>>
>>> scores    = jnp.array([0.1, 0.4, 0.3, 0.2])
>>> relevance = jnp.array([  0,   1,   1,   0])
>>> metric = PrecisionAtK(k=2)
>>> metric.update(labels=relevance, scores=scores)
PrecisionAtK(...)
>>> metric.compute()  # top-2 are indices 1 and 2, both relevant
Array(1., dtype=float32)

compute() → Array

Compute and return the precision@k.

reset() → Self

Reset the metric state in-place.

update(labels: Array, scores: Array, *, mask: Array | None = None, **_) → Self

Update the precision@k with a batch of scored items.

Parameters:
  • labels – Relevance labels, shape (..., num_items).

  • scores – Scores for each item, same shape as labels.

  • mask – Binary mask indicating which queries to include.

class flax_metrics.ranking.RecallAtK(k: int)

Recall@K, the fraction of relevant items that appear in the top-k ranked results.

Recall is computed per query and then averaged across queries (macro-average).
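Macro-averaging means each query's recall counts equally, regardless of how many relevant items it has. The plain-Python sketch below makes this concrete. The function names are hypothetical, and scoring 0 for a query with no relevant items is an assumption.

```python
def recall_at_k(labels, scores, k):
    """Fraction of a query's relevant items found in its top-k (illustrative sketch)."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    total_relevant = sum(1 for rel in labels if rel > 0)
    if total_relevant == 0:
        return 0.0  # assumption: queries without relevant items score 0
    hits = sum(1 for i in order[:k] if labels[i] > 0)
    return hits / total_relevant


def macro_recall_at_k(queries, k):
    """Average the per-query recalls, so every query carries equal weight."""
    recalls = [recall_at_k(labels, scores, k) for labels, scores in queries]
    return sum(recalls) / len(recalls)


queries = [
    ([1, 1, 1, 0], [0.1, 0.4, 0.3, 0.2]),  # doctest query: 2 of 3 relevant in top-2
    ([1, 0, 0, 0], [0.9, 0.1, 0.2, 0.3]),  # 1 of 1 relevant in top-2
]
print(macro_recall_at_k(queries, k=2))  # ≈ 0.8333, i.e. (2/3 + 1) / 2
```

A pooled (micro) average over the same two queries would instead give (2 + 1) / (3 + 1) = 0.75.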

See also

This metric is implemented in ir-measures as R.

Parameters:

k – Number of top items to consider.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import RecallAtK
>>>
>>> scores    = jnp.array([0.1, 0.4, 0.3, 0.2])
>>> relevance = jnp.array([  1,   1,   1,   0])
>>> metric = RecallAtK(k=2)
>>> metric.update(labels=relevance, scores=scores)
RecallAtK(...)
>>> metric.compute()  # 2 of 3 relevant items in top-2
Array(0.6666667, dtype=float32)

compute() → Array

Compute and return the recall@k.

reset() → Self

Reset the metric state in-place.

update(labels: Array, scores: Array, *, mask: Array | None = None, **_) → Self

Update the recall@k with a batch of scored items.

Parameters:
  • labels – Relevance labels, shape (..., num_items).

  • scores – Scores for each item, same shape as labels.

  • mask – Binary mask indicating which queries to include.