Dot Product Ranking Metrics
Metrics for evaluating ranked retrieval where scores are computed as dot products between query and key embeddings. These are useful for dense retrieval and embedding-based recommendation systems where computing all pairwise scores is prohibitive, so only a sampled subset of candidates is evaluated.
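As a rough illustration of the sampled scoring described above, the following NumPy sketch gathers a per-query subset of keys and scores only those. The toy arrays and the use of einsum for batched queries are assumptions made for illustration; they are not taken from the flax_metrics implementation.

```python
import numpy as np

# Hypothetical toy data: 5 candidates with 2 features, and 2 queries.
keys = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [0.0, 2.0], [-1.0, 0.0]])
query = np.array([[1.0, 0.0], [0.0, 1.0]])   # (batch, num_features)
indices = np.array([[0, 1, 2], [2, 3, 4]])   # (batch, num_sampled)

# Gather only the sampled keys: (batch, num_sampled, num_features).
sampled_keys = keys[indices]

# One dot product per (query, sampled key) pair: (batch, num_sampled).
scores = np.einsum("...f,...sf->...s", query, sampled_keys)

# Positions of the sampled candidates per query, best score first.
order = np.argsort(-scores, axis=-1)
```

Only batch * num_sampled dot products are computed here, rather than batch * num_candidates, which is the saving that sampling is meant to provide.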
- class flax_metrics.dot_product_ranking.DotProductMeanAveragePrecision(k: int | None = None)
Mean Average Precision using dot product scores between query and key embeddings.
See also
This metric is implemented in ir-measures as AP.
Note
The ranked score is computed as query @ keys[indices].T, where query are embeddings with shape (..., num_features) and keys are embeddings with shape (num_candidates, num_features). When the number of candidates is large, we only consider a subset of them, indicated by indices with shape (..., num_sampled). The ... indicates batch dimensions that are broadcastable.
- Parameters:
k – Number of top items to consider. If None, considers all sampled items.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import DotProductMeanAveragePrecision
>>>
>>> query = jnp.array([1.0, 0.0])
>>> keys = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> indices = jnp.array([0, 1, 2])
>>> relevance = jnp.array([1, 0, 1])
>>> metric = DotProductMeanAveragePrecision()
>>> metric.update(labels=relevance, query=query, keys=keys, indices=indices)
DotProductMeanAveragePrecision(...)
>>> metric.compute()  # (1/1 + 2/3) / 2
Array(0.8333334, dtype=float32)
- reset() → Self
Reset the metric state in-place.
- update(labels: Array, query: Array, keys: Array, indices: Array, *, mask: Array | None = None, **_) → Self
Update the mean average precision with a batch of query/key embeddings.
- Parameters:
labels – Relevance labels for indexed items, shape (*batch_shape, num_sampled).
query – Query embeddings, shape (*batch_shape, num_features).
keys – Key embeddings for all candidates, shape (num_candidates, num_features).
indices – Indices into keys for each query, shape (*batch_shape, num_sampled).
mask – Binary mask indicating which queries to include.
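For reference, the value in the DotProductMeanAveragePrecision example can be reproduced by hand. The sketch below is plain Python and does not use flax_metrics; the helper name average_precision is hypothetical, and the normalization by the total number of relevant sampled items is an assumption (it does not affect the result when k is None).

```python
def average_precision(scores, labels, k=None):
    """Average precision for one query over its sampled candidates."""
    # Candidate positions sorted by descending score.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    if k is not None:
        order = order[:k]
    hits, precision_sum = 0, 0.0
    for rank, i in enumerate(order, start=1):
        if labels[i] > 0:
            hits += 1
            precision_sum += hits / rank  # precision at this relevant rank
    # Assumption: normalize by all relevant sampled items.
    num_relevant = sum(1 for label in labels if label > 0)
    return precision_sum / num_relevant if num_relevant else 0.0

query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
indices = [0, 1, 2]
# Dot product of the query with each sampled key: [1.0, 0.5, 0.0].
scores = [sum(q * x for q, x in zip(query, keys[i])) for i in indices]
ap = average_precision(scores, labels=[1, 0, 1])  # (1/1 + 2/3) / 2
```

The relevant items sit at ranks 1 and 3, so the precisions 1/1 and 2/3 are averaged.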
- class flax_metrics.dot_product_ranking.DotProductMeanReciprocalRank(k: int | None = None)
Mean Reciprocal Rank using dot product scores between query and key embeddings.
See also
This metric is implemented in ir-measures as RR.
Note
The ranked score is computed as query @ keys[indices].T, where query are embeddings with shape (..., num_features) and keys are embeddings with shape (num_candidates, num_features). When the number of candidates is large, we only consider a subset of them, indicated by indices with shape (..., num_sampled). The ... indicates batch dimensions that are broadcastable.
- Parameters:
k – Number of top items to consider. If None, considers all sampled items.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import DotProductMeanReciprocalRank
>>>
>>> query = jnp.array([1.0, 0.0])
>>> keys = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> indices = jnp.array([0, 1, 2])
>>> relevance = jnp.array([0, 0, 1])
>>> metric = DotProductMeanReciprocalRank()
>>> metric.update(labels=relevance, query=query, keys=keys, indices=indices)
DotProductMeanReciprocalRank(...)
>>> metric.compute()  # first relevant at rank 3
Array(0.33333334, dtype=float32)
- reset() → Self
Reset the metric state in-place.
- update(labels: Array, query: Array, keys: Array, indices: Array, *, mask: Array | None = None, **_) → Self
Update the mean reciprocal rank with a batch of query/key embeddings.
- Parameters:
labels – Relevance labels for indexed items, shape (*batch_shape, num_sampled).
query – Query embeddings, shape (*batch_shape, num_features).
keys – Key embeddings for all candidates, shape (num_candidates, num_features).
indices – Indices into keys for each query, shape (*batch_shape, num_sampled).
mask – Binary mask indicating which queries to include.
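The reciprocal rank from the DotProductMeanReciprocalRank example can be checked with a few lines of plain Python. This is a sketch independent of flax_metrics; the helper name reciprocal_rank is hypothetical, and returning 0 when no relevant item appears in the top k is an assumption about the cutoff behaviour.

```python
def reciprocal_rank(scores, labels, k=None):
    """Reciprocal rank of the first relevant item for one query."""
    # Candidate positions sorted by descending score.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    if k is not None:
        order = order[:k]
    for rank, i in enumerate(order, start=1):
        if labels[i] > 0:
            return 1.0 / rank
    # Assumption: no relevant item within the cutoff scores 0.
    return 0.0

scores = [1.0, 0.5, 0.0]  # dot products of [1, 0] with the three keys
rr = reciprocal_rank(scores, labels=[0, 0, 1])            # rank 3
rr_top2 = reciprocal_rank(scores, labels=[0, 0, 1], k=2)  # outside top 2
```

With k=2 the only relevant item falls outside the cutoff, so this sketch yields 0 for that query.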
- class flax_metrics.dot_product_ranking.DotProductNDCG(k: int | None = None)
Normalized Discounted Cumulative Gain using dot product scores between query and key embeddings.
See also
This metric is implemented in ir-measures as nDCG.
Note
The ranked score is computed as query @ keys[indices].T, where query are embeddings with shape (..., num_features) and keys are embeddings with shape (num_candidates, num_features). When the number of candidates is large, we only consider a subset of them, indicated by indices with shape (..., num_sampled). The ... indicates batch dimensions that are broadcastable.
- Parameters:
k – Number of top items to consider. If None, considers all sampled items.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import DotProductNDCG
>>>
>>> query = jnp.array([1.0, 0.0])
>>> keys = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> indices = jnp.array([0, 1, 2])
>>> relevance = jnp.array([1, 3, 2])
>>> metric = DotProductNDCG()
>>> metric.update(labels=relevance, query=query, keys=keys, indices=indices)
DotProductNDCG(...)
>>> metric.compute()  # DCG / IDCG
Array(0.8174..., dtype=float32)
- reset() → Self
Reset the metric state in-place.
- update(labels: Array, query: Array, keys: Array, indices: Array, *, mask: Array | None = None, **_) → Self
Update the NDCG with a batch of query/key embeddings.
- Parameters:
labels – Relevance labels for indexed items, shape (*batch_shape, num_sampled).
query – Query embeddings, shape (*batch_shape, num_features).
keys – Key embeddings for all candidates, shape (num_candidates, num_features).
indices – Indices into keys for each query, shape (*batch_shape, num_sampled).
mask – Binary mask indicating which queries to include.
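The DCG / IDCG ratio in the DotProductNDCG example can be worked through in plain Python. The sketch below assumes linear gains (the label itself) and a log2(rank + 1) discount; the gain function is not stated here, but these choices reproduce the value shown in the example. The helper name ndcg is hypothetical and the code does not use flax_metrics.

```python
import math

def ndcg(scores, labels, k=None):
    """NDCG for one query, assuming linear gains and a log2 discount."""
    def dcg(gains):
        gains = gains if k is None else gains[:k]
        return sum(g / math.log2(rank + 1)
                   for rank, g in enumerate(gains, start=1))

    # Labels in the order induced by descending score.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranked = [labels[i] for i in order]
    # Ideal ordering: labels sorted from most to least relevant.
    ideal = sorted(labels, reverse=True)
    idcg = dcg(ideal)
    return dcg(ranked) / idcg if idcg > 0 else 0.0

scores = [1.0, 0.5, 0.0]  # dot products of [1, 0] with the three keys
value = ndcg(scores, labels=[1, 3, 2])
```

Here the ranked labels are [1, 3, 2] while the ideal order is [3, 2, 1], which is why the ratio falls below 1.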
- class flax_metrics.dot_product_ranking.DotProductPrecisionAtK(k: int)
Precision@K using dot product scores between query and key embeddings.
See also
This metric is implemented in ir-measures as P.
Note
The ranked score is computed as query @ keys[indices].T, where query are embeddings with shape (..., num_features) and keys are embeddings with shape (num_candidates, num_features). When the number of candidates is large, we only consider a subset of them, indicated by indices with shape (..., num_sampled). The ... indicates batch dimensions that are broadcastable.
- Parameters:
k – Number of top items to consider.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import DotProductPrecisionAtK
>>>
>>> query = jnp.array([1.0, 0.0])
>>> keys = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> indices = jnp.array([0, 1, 2])
>>> relevance = jnp.array([1, 0, 1])
>>> metric = DotProductPrecisionAtK(k=2)
>>> metric.update(labels=relevance, query=query, keys=keys, indices=indices)
DotProductPrecisionAtK(...)
>>> metric.compute()  # top-2 by score are indices 0 (relevant), 1 (not)
Array(0.5, dtype=float32)
- reset() → Self
Reset the metric state in-place.
- update(labels: Array, query: Array, keys: Array, indices: Array, *, mask: Array | None = None, **_) → Self
Update the precision@k with a batch of query/key embeddings.
- Parameters:
labels – Relevance labels for indexed items, shape (*batch_shape, num_sampled).
query – Query embeddings, shape (*batch_shape, num_features).
keys – Key embeddings for all candidates, shape (num_candidates, num_features).
indices – Indices into keys for each query, shape (*batch_shape, num_sampled).
mask – Binary mask indicating which queries to include.
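Precision@K for a single query reduces to counting relevant items among the k highest-scoring sampled candidates. The following plain-Python sketch reproduces the DotProductPrecisionAtK example; the helper name precision_at_k is hypothetical, and dividing by k even when fewer than k candidates are sampled is an assumption.

```python
def precision_at_k(scores, labels, k):
    """Fraction of the top-k scored candidates that are relevant."""
    # Candidate positions sorted by descending score, cut off at k.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    top_k = order[:k]
    hits = sum(1 for i in top_k if labels[i] > 0)
    # Assumption: the denominator is always k.
    return hits / k

scores = [1.0, 0.5, 0.0]  # dot products of [1, 0] with the three keys
p_at_2 = precision_at_k(scores, labels=[1, 0, 1], k=2)
p_at_3 = precision_at_k(scores, labels=[1, 0, 1], k=3)
```

Raising k to 3 brings in the second relevant item, giving 2 hits out of 3.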
- class flax_metrics.dot_product_ranking.DotProductRecallAtK(k: int)
Recall@K using dot product scores between query and key embeddings.
Computes mean recall over all queries (macro-average).
See also
This metric is implemented in ir-measures as R.
Note
The ranked score is computed as query @ keys[indices].T, where query are embeddings with shape (..., num_features) and keys are embeddings with shape (num_candidates, num_features). When the number of candidates is large, we only consider a subset of them, indicated by indices with shape (..., num_sampled). The ... indicates batch dimensions that are broadcastable.
- Parameters:
k – Number of top items to consider.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import DotProductRecallAtK
>>>
>>> query = jnp.array([1.0, 0.0])
>>> keys = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
>>> indices = jnp.array([0, 1, 2])
>>> relevance = jnp.array([1, 1, 1])
>>> metric = DotProductRecallAtK(k=2)
>>> metric.update(labels=relevance, query=query, keys=keys, indices=indices)
DotProductRecallAtK(...)
>>> metric.compute()  # 2 of 3 relevant items in top-2
Array(0.6666667, dtype=float32)
- reset() → Self
Reset the metric state in-place.
- update(labels: Array, query: Array, keys: Array, indices: Array, *, mask: Array | None = None, **_) → Self
Update the recall@k with a batch of query/key embeddings.
- Parameters:
labels – Relevance labels for indexed items, shape (*batch_shape, num_sampled).
query – Query embeddings, shape (*batch_shape, num_features).
keys – Key embeddings for all candidates, shape (num_candidates, num_features).
indices – Indices into keys for each query, shape (*batch_shape, num_sampled).
mask – Binary mask indicating which queries to include.
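To make the macro-average in DotProductRecallAtK concrete, the plain-Python sketch below computes recall@k per query and then takes the unweighted mean over queries. The helper name recall_at_k and the second query's scores and labels are hypothetical, and scoring a query with no relevant items as 0 is an assumption; none of this is taken from the flax_metrics implementation.

```python
def recall_at_k(scores, labels, k):
    """Fraction of a query's relevant items found in its top-k."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    num_relevant = sum(1 for label in labels if label > 0)
    if num_relevant == 0:
        return 0.0  # assumption: queries without relevant items score 0
    hits = sum(1 for i in order[:k] if labels[i] > 0)
    return hits / num_relevant

# First query matches the example above; the second is made up.
batch_scores = [[1.0, 0.5, 0.0], [0.2, 0.9, 0.4]]
batch_labels = [[1, 1, 1], [0, 1, 0]]

per_query = [recall_at_k(s, l, k=2)
             for s, l in zip(batch_scores, batch_labels)]
# Macro-average: every query contributes equally, regardless of how
# many relevant items it has.
macro_recall = sum(per_query) / len(per_query)
```

The first query recovers 2 of its 3 relevant items and the second recovers its single relevant item, so the macro-average is the mean of 2/3 and 1.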