Dot Product Ranking Metrics

Metrics for evaluating ranked retrieval where scores are computed as dot products between query and key embeddings. They are intended for dense retrieval and embedding-based recommendation systems, where scoring every query against every candidate is prohibitively expensive and only a sampled subset of candidates is evaluated per query.

class flax_metrics.dot_product_ranking.DotProductMeanAveragePrecision(k: int | None = None)

Mean Average Precision using dot product scores between query and key embeddings.

See also

This metric is implemented in ir-measures as AP.

Note

The ranking scores are computed as query @ keys[indices].T, where query holds embeddings of shape (..., num_features) and keys holds embeddings of shape (num_candidates, num_features). When the number of candidates is large, only a subset of them is scored, selected by indices of shape (..., num_sampled). Here ... denotes broadcastable batch dimensions.
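As a rough illustration of the scoring step described in this note, the sketch below gathers the sampled keys for each query and contracts the feature axis. It uses NumPy as a stand-in for jax.numpy, and the helper name sampled_dot_product_scores is hypothetical; it is not part of flax_metrics and may differ from the library's internals.

```python
import numpy as np

def sampled_dot_product_scores(query, keys, indices):
    """Score each query against its own sampled candidates (hypothetical helper).

    query:   (..., num_features)
    keys:    (num_candidates, num_features)
    indices: (..., num_sampled)
    returns: (..., num_sampled)
    """
    gathered = keys[indices]  # (..., num_sampled, num_features)
    # Contract the feature axis; leading batch axes are matched up by einsum.
    return np.einsum("...f,...sf->...s", query, gathered)

query = np.array([[1.0, 0.0], [0.0, 1.0]])             # two queries
keys = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])  # three candidates
indices = np.array([[0, 1, 2], [2, 0, 1]])             # per-query samples

scores = sampled_dot_product_scores(query, keys, indices)
# Positions along the num_sampled axis, best score first.
ranking = np.argsort(-scores, axis=-1, kind="stable")
```

Note that ranking holds positions within indices, not positions within keys, so the relevance labels must be aligned with indices for the metrics to line up.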

Parameters:

k – Number of top items to consider. If None, considers all sampled items.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import DotProductMeanAveragePrecision
>>>
>>> query = jnp.array([1.0, 0.0])
>>> keys = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> indices = jnp.array([0, 1, 2])
>>> relevance = jnp.array([1, 0, 1])
>>> metric = DotProductMeanAveragePrecision()
>>> metric.update(labels=relevance, query=query, keys=keys, indices=indices)
DotProductMeanAveragePrecision(...)
>>> metric.compute()  # (1/1 + 2/3) / 2
Array(0.8333334, dtype=float32)

compute() → Array

Compute and return the mean average precision.
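For intuition, average precision for a single query can be sketched in plain NumPy as below. This is a reference sketch, not the library's implementation: the function name average_precision is hypothetical, and normalizing by the number of relevant items found within the top k is an assumption.

```python
import numpy as np

def average_precision(scores, labels, k=None):
    """Average precision for one query (hypothetical reference helper)."""
    order = np.argsort(-scores, kind="stable")  # best score first
    ranked = labels[order][:k] > 0              # binary relevance in ranked order
    num_relevant = ranked.sum()
    if num_relevant == 0:
        return 0.0
    hits = np.cumsum(ranked)                    # relevant items seen so far
    ranks = np.arange(1, ranked.size + 1)
    # Precision at each relevant position, averaged over relevant items.
    return float(np.sum(ranked * hits / ranks) / num_relevant)

scores = np.array([1.0, 0.5, 0.0])   # query @ keys[indices].T from the example
labels = np.array([1, 0, 1])
ap = average_precision(scores, labels)  # (1/1 + 2/3) / 2
```

The mean over queries of these per-query values gives the mean average precision.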

reset() → Self

Reset the metric state in-place.

update(labels: Array, query: Array, keys: Array, indices: Array, *, mask: Array | None = None, **_) → Self

Update the mean average precision with a batch of query/key embeddings.

Parameters:
  • labels – Relevance labels for indexed items, shape (*batch_shape, num_sampled).

  • query – Query embeddings, shape (*batch_shape, num_features).

  • keys – Key embeddings for all candidates, shape (num_candidates, num_features).

  • indices – Indices into keys for each query, shape (*batch_shape, num_sampled).

  • mask – Binary mask indicating which queries to include.
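One way to picture how repeated update calls and the mask could interact is a running masked mean over per-query values. The sketch below is an assumption about the aggregation, not a description of the library's state handling, and the RunningMean class is hypothetical.

```python
import numpy as np

class RunningMean:
    """Hypothetical accumulator: sum of per-query values and count of included queries."""

    def __init__(self):
        self.total = 0.0
        self.count = 0.0

    def update(self, values, mask=None):
        values = np.asarray(values, dtype=float)
        # Without a mask, every query in the batch is included.
        mask = np.ones_like(values) if mask is None else np.asarray(mask, dtype=float)
        self.total += float(np.sum(values * mask))
        self.count += float(np.sum(mask))
        return self

    def compute(self):
        return self.total / self.count if self.count > 0 else 0.0

metric = RunningMean()
metric.update([1.0, 0.5, 0.0], mask=[1, 1, 0])  # third query is excluded (e.g. padding)
metric.update([0.25])
result = metric.compute()  # (1.0 + 0.5 + 0.25) / 3
```

Under this assumption, masked-out queries contribute to neither the numerator nor the denominator, so padding a batch does not change the result.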

class flax_metrics.dot_product_ranking.DotProductMeanReciprocalRank(k: int | None = None)

Mean Reciprocal Rank using dot product scores between query and key embeddings.

See also

This metric is implemented in ir-measures as RR.

Note

The ranking scores are computed as query @ keys[indices].T, where query holds embeddings of shape (..., num_features) and keys holds embeddings of shape (num_candidates, num_features). When the number of candidates is large, only a subset of them is scored, selected by indices of shape (..., num_sampled). Here ... denotes broadcastable batch dimensions.

Parameters:

k – Number of top items to consider. If None, considers all sampled items.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import DotProductMeanReciprocalRank
>>>
>>> query = jnp.array([1.0, 0.0])
>>> keys = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> indices = jnp.array([0, 1, 2])
>>> relevance = jnp.array([0, 0, 1])
>>> metric = DotProductMeanReciprocalRank()
>>> metric.update(labels=relevance, query=query, keys=keys, indices=indices)
DotProductMeanReciprocalRank(...)
>>> metric.compute()  # first relevant at rank 3
Array(0.33333334, dtype=float32)

compute() → Array

Compute and return the mean reciprocal rank.
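The per-query reciprocal rank can be sketched as follows. This is a NumPy reference sketch under the assumption that a query with no relevant item in the top k contributes 0; the function name reciprocal_rank is hypothetical.

```python
import numpy as np

def reciprocal_rank(scores, labels, k=None):
    """Reciprocal rank of the first relevant item for one query (hypothetical helper)."""
    order = np.argsort(-scores, kind="stable")  # best score first
    ranked = labels[order][:k] > 0
    if not ranked.any():
        return 0.0                              # assumed convention when nothing relevant is found
    first = int(np.argmax(ranked))              # zero-based position of the first relevant item
    return 1.0 / (first + 1)

scores = np.array([1.0, 0.5, 0.0])
labels = np.array([0, 0, 1])
rr = reciprocal_rank(scores, labels)             # first relevant at rank 3, so 1/3
rr_top2 = reciprocal_rank(scores, labels, k=2)   # relevant item falls outside the top 2
```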

reset() → Self

Reset the metric state in-place.

update(labels: Array, query: Array, keys: Array, indices: Array, *, mask: Array | None = None, **_) → Self

Update the mean reciprocal rank with a batch of query/key embeddings.

Parameters:
  • labels – Relevance labels for indexed items, shape (*batch_shape, num_sampled).

  • query – Query embeddings, shape (*batch_shape, num_features).

  • keys – Key embeddings for all candidates, shape (num_candidates, num_features).

  • indices – Indices into keys for each query, shape (*batch_shape, num_sampled).

  • mask – Binary mask indicating which queries to include.

class flax_metrics.dot_product_ranking.DotProductNDCG(k: int | None = None)

Normalized Discounted Cumulative Gain using dot product scores between query and key embeddings.

See also

This metric is implemented in ir-measures as nDCG.

Note

The ranking scores are computed as query @ keys[indices].T, where query holds embeddings of shape (..., num_features) and keys holds embeddings of shape (num_candidates, num_features). When the number of candidates is large, only a subset of them is scored, selected by indices of shape (..., num_sampled). Here ... denotes broadcastable batch dimensions.

Parameters:

k – Number of top items to consider. If None, considers all sampled items.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import DotProductNDCG
>>>
>>> query = jnp.array([1.0, 0.0])
>>> keys = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> indices = jnp.array([0, 1, 2])
>>> relevance = jnp.array([1, 3, 2])
>>> metric = DotProductNDCG()
>>> metric.update(labels=relevance, query=query, keys=keys, indices=indices)
DotProductNDCG(...)
>>> metric.compute()  # DCG / IDCG
Array(0.8174..., dtype=float32)

compute() → Array

Compute and return the NDCG.
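A per-query NDCG can be sketched in NumPy as below. The sketch assumes linear gain (the label itself) and a 1 / log2(rank + 1) discount; this choice is an assumption, although it does reproduce the value shown in the example for this class. The function name ndcg is hypothetical.

```python
import numpy as np

def ndcg(scores, labels, k=None):
    """NDCG for one query with linear gain (hypothetical reference helper)."""
    labels = np.asarray(labels, dtype=float)
    # Discount for ranks 1..n is 1 / log2(rank + 1).
    discounts = 1.0 / np.log2(np.arange(2, labels.size + 2))
    order = np.argsort(-scores, kind="stable")  # predicted order, best first
    ideal = np.sort(labels)[::-1]               # best achievable order
    dcg = np.sum((labels[order] * discounts)[:k])
    idcg = np.sum((ideal * discounts)[:k])
    return float(dcg / idcg) if idcg > 0 else 0.0

scores = np.array([1.0, 0.5, 0.0])
labels = np.array([1, 3, 2])
value = ndcg(scores, labels)  # about 0.8175
```

With an exponential gain (2**label - 1) the same inputs would give a different value, so the gain function is worth checking when comparing against other tools.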

reset() → Self

Reset the metric state in-place.

update(labels: Array, query: Array, keys: Array, indices: Array, *, mask: Array | None = None, **_) → Self

Update the NDCG with a batch of query/key embeddings.

Parameters:
  • labels – Relevance labels for indexed items, shape (*batch_shape, num_sampled).

  • query – Query embeddings, shape (*batch_shape, num_features).

  • keys – Key embeddings for all candidates, shape (num_candidates, num_features).

  • indices – Indices into keys for each query, shape (*batch_shape, num_sampled).

  • mask – Binary mask indicating which queries to include.

class flax_metrics.dot_product_ranking.DotProductPrecisionAtK(k: int)

Precision@K using dot product scores between query and key embeddings.

See also

This metric is implemented in ir-measures as P.

Note

The ranking scores are computed as query @ keys[indices].T, where query holds embeddings of shape (..., num_features) and keys holds embeddings of shape (num_candidates, num_features). When the number of candidates is large, only a subset of them is scored, selected by indices of shape (..., num_sampled). Here ... denotes broadcastable batch dimensions.

Parameters:

k – Number of top items to consider.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import DotProductPrecisionAtK
>>>
>>> query = jnp.array([1.0, 0.0])
>>> keys = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> indices = jnp.array([0, 1, 2])
>>> relevance = jnp.array([1, 0, 1])
>>> metric = DotProductPrecisionAtK(k=2)
>>> metric.update(labels=relevance, query=query, keys=keys, indices=indices)
DotProductPrecisionAtK(...)
>>> metric.compute()  # top-2 by score are indices 0 (relevant), 1 (not)
Array(0.5, dtype=float32)

compute() → Array

Compute and return the precision@k.
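Precision@k for one query reduces to counting relevant items among the k highest-scoring candidates. The NumPy sketch below illustrates this; the function name precision_at_k is hypothetical, and dividing by k even when fewer than k candidates are sampled is an assumption.

```python
import numpy as np

def precision_at_k(scores, labels, k):
    """Fraction of the top-k ranked items that are relevant (hypothetical helper)."""
    order = np.argsort(-scores, kind="stable")  # best score first
    top_k = labels[order][:k] > 0
    return float(np.sum(top_k) / k)

scores = np.array([1.0, 0.5, 0.0])
labels = np.array([1, 0, 1])
p_at_2 = precision_at_k(scores, labels, k=2)  # one relevant item in the top 2
```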

reset() → Self

Reset the metric state in-place.

update(labels: Array, query: Array, keys: Array, indices: Array, *, mask: Array | None = None, **_) → Self

Update the precision@k with a batch of query/key embeddings.

Parameters:
  • labels – Relevance labels for indexed items, shape (*batch_shape, num_sampled).

  • query – Query embeddings, shape (*batch_shape, num_features).

  • keys – Key embeddings for all candidates, shape (num_candidates, num_features).

  • indices – Indices into keys for each query, shape (*batch_shape, num_sampled).

  • mask – Binary mask indicating which queries to include.

class flax_metrics.dot_product_ranking.DotProductRecallAtK(k: int)

Recall@K using dot product scores between query and key embeddings.

Computes mean recall over all queries (macro-average).

See also

This metric is implemented in ir-measures as R.

Note

The ranking scores are computed as query @ keys[indices].T, where query holds embeddings of shape (..., num_features) and keys holds embeddings of shape (num_candidates, num_features). When the number of candidates is large, only a subset of them is scored, selected by indices of shape (..., num_sampled). Here ... denotes broadcastable batch dimensions.

Parameters:

k – Number of top items to consider.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import DotProductRecallAtK
>>>
>>> query = jnp.array([1.0, 0.0])
>>> keys = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> indices = jnp.array([0, 1, 2])
>>> relevance = jnp.array([1, 1, 1])
>>> metric = DotProductRecallAtK(k=2)
>>> metric.update(labels=relevance, query=query, keys=keys, indices=indices)
DotProductRecallAtK(...)
>>> metric.compute()  # 2 of 3 relevant items in top-2
Array(0.6666667, dtype=float32)

compute() → Array

Compute and return the recall@k.
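To make the macro-average concrete, the sketch below computes recall@k per query over a small batch and then averages the per-query values. It is a NumPy reference sketch; the function name recall_at_k is hypothetical, and treating a query with no relevant sampled items as recall 0 is an assumption.

```python
import numpy as np

def recall_at_k(scores, labels, k):
    """Per-query recall@k over a batch (hypothetical reference helper)."""
    order = np.argsort(-scores, axis=-1, kind="stable")  # best score first
    ranked = np.take_along_axis(labels, order, axis=-1) > 0
    hits = ranked[..., :k].sum(axis=-1)    # relevant items retrieved in the top k
    total = ranked.sum(axis=-1)            # relevant items among all sampled candidates
    # Guard against division by zero for queries without relevant items.
    return np.where(total > 0, hits / np.maximum(total, 1), 0.0)

scores = np.array([[1.0, 0.5, 0.0], [0.2, 0.9, 0.1]])
labels = np.array([[1, 1, 1], [1, 0, 0]])
per_query = recall_at_k(scores, labels, k=2)  # [2/3, 1.0]
macro = per_query.mean()                      # each query weighted equally
```

A micro-average would instead pool hits and relevant counts across queries before dividing, which weights queries with many relevant items more heavily.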

reset() → Self

Reset the metric state in-place.

update(labels: Array, query: Array, keys: Array, indices: Array, *, mask: Array | None = None, **_) → Self

Update the recall@k with a batch of query/key embeddings.

Parameters:
  • labels – Relevance labels for indexed items, shape (*batch_shape, num_sampled).

  • query – Query embeddings, shape (*batch_shape, num_features).

  • keys – Key embeddings for all candidates, shape (num_candidates, num_features).

  • indices – Indices into keys for each query, shape (*batch_shape, num_sampled).

  • mask – Binary mask indicating which queries to include.