Base Metrics

class flax_metrics.base.Average

Average metric, the arithmetic mean of values.

Parameters:

argname – Name of the keyword argument to average.

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import Average
>>>
>>> values = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> metric = Average()
>>> metric.update(values=values)
Average(...)
>>> metric.compute()
Array(2.5, dtype=float32)

compute() → Array

Compute and return the average.

reset() → Self

Reset the metric state.

update(values: Array, *_args, mask: Array | None = None, **_kwargs) → Self

Update the metric in-place.

Parameters:
  • values – Array of values to average.

  • *_args – Positional arguments.

  • mask – Binary mask indicating which elements to include.

  • **_kwargs – Keyword arguments.

Returns:

The metric instance.
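The running-mean bookkeeping described above can be sketched in plain Python. This is a hypothetical AverageSketch class, not the library's implementation: it assumes the state is a running total and count, that mask entries are 0 or 1 with 0 excluding the matching value, and it uses lists instead of JAX arrays. The argname constructor parameter is not modeled.

```python
from typing import Optional, Sequence


class AverageSketch:
    """Hypothetical stand-in for Average that keeps a running total and count."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> "AverageSketch":
        self.total = 0.0
        self.count = 0.0
        return self

    def update(self, values: Sequence[float], *_args,
               mask: Optional[Sequence[float]] = None, **_kwargs) -> "AverageSketch":
        # Assumption: mask entries are 0 or 1; a 0 drops the matching value.
        weights = [1.0] * len(values) if mask is None else mask
        for value, weight in zip(values, weights):
            self.total += value * weight
            self.count += weight
        return self  # returning the instance lets calls be chained

    def compute(self) -> float:
        # Guard against an empty or fully masked stream.
        return self.total / self.count if self.count else float("nan")


# Two batches give the same result as one batch of all four values.
print(AverageSketch().update([1.0, 2.0]).update([3.0, 4.0]).compute())  # 2.5
# Masking out the last two values averages only 1.0 and 2.0.
print(AverageSketch().update([1.0, 2.0, 3.0, 4.0], mask=[1, 1, 0, 0]).compute())  # 1.5
```

Keeping a total and a count, rather than the values themselves, is what lets an average be accumulated across batches in constant memory.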

class flax_metrics.base.BaseMetric

Base class for Flax Metrics implementations.

This class inherits from flax.nnx.metrics.Metric so that isinstance checks against that class succeed. It overrides update() to accept arbitrary positional and keyword arguments as well as a mask parameter, and update() returns Self so that calls to update() and compute() can be chained.
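The contract described above (in-place update() and reset() that return the instance, plus compute()) can be illustrated with a plain-Python analogue. MetricSketch and MaxSketch are hypothetical names used only for illustration; the sketch uses abc.ABC rather than subclassing flax.nnx.metrics.Metric, and it assumes that a mask entry of 0 excludes the matching value.

```python
from abc import ABC, abstractmethod
from typing import Optional, Sequence


class MetricSketch(ABC):
    """Hypothetical plain-Python analogue of the BaseMetric contract."""

    @abstractmethod
    def reset(self) -> "MetricSketch":
        """Clear the state in-place and return the metric."""

    @abstractmethod
    def update(self, *args, mask: Optional[Sequence[int]] = None, **kwargs) -> "MetricSketch":
        """Fold new data into the state in-place and return the metric."""

    @abstractmethod
    def compute(self):
        """Return the metric value for the current state."""


class MaxSketch(MetricSketch):
    """Hypothetical metric tracking the largest unmasked value seen so far."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> "MaxSketch":
        self.best = float("-inf")
        return self

    def update(self, values: Sequence[float], *_args,
               mask: Optional[Sequence[int]] = None, **_kwargs) -> "MaxSketch":
        keep_flags = [1] * len(values) if mask is None else mask
        for value, keep in zip(values, keep_flags):
            if keep:  # assumption: 0 in the mask skips the value
                self.best = max(self.best, value)
        return self

    def compute(self) -> float:
        return self.best


# update() returns the instance, so updates and compute() chain in one expression.
print(MaxSketch().update([1.0, 5.0]).update([3.0, 9.0], mask=[1, 0]).compute())  # 5.0
```

Because every concrete metric derives from one base class, code that receives "some metric" can rely on isinstance checks and on the same update/compute/reset calls.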

compute()

Compute and return the value of the metric.

reset() → Self

Reset the state of the metric in-place.

Returns:

The metric instance.

update(*args, mask: Array | None = None, **kwargs) → Self

Update the metric in-place.

Parameters:
  • *args – Positional arguments.

  • mask – Binary mask indicating which elements to include.

  • **kwargs – Keyword arguments.

Returns:

The metric instance.

class flax_metrics.base.Statistics(mean: Array, standard_error_of_mean: Array, standard_deviation: Array)

Named tuple of the statistics computed by the Welford metric.

count(value, /)

Return number of occurrences of value.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

mean: Array

Mean of the values (alias for field number 0).

standard_deviation: Array

Standard deviation of the values (alias for field number 2).

standard_error_of_mean: Array

Standard error of the mean (alias for field number 1).
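Because Statistics is a named tuple, the count() and index() entries listed above are the ordinary tuple methods, and each field name is an alias for a tuple position. The sketch below uses a hypothetical StatisticsSketch with the same field order and plain floats in place of JAX arrays; the numbers are the ones from the Welford example for [1, 2, 3, 4].

```python
import math
from typing import NamedTuple


class StatisticsSketch(NamedTuple):
    """Hypothetical stand-in with the same field order as Statistics.

    Plain floats replace JAX arrays so the sketch runs without JAX.
    """

    mean: float
    standard_error_of_mean: float
    standard_deviation: float


# Population standard deviation of [1, 2, 3, 4] is sqrt(1.25); n = 4.
std = math.sqrt(1.25)
stats = StatisticsSketch(mean=2.5,
                         standard_error_of_mean=std / math.sqrt(4),
                         standard_deviation=std)

# Each field name is an alias for a tuple position.
assert stats.mean == stats[0]
assert stats.standard_error_of_mean == stats[1]
assert stats.standard_deviation == stats[2]

# count() and index() behave exactly as they do on any tuple.
print(stats.index(2.5))  # 0
print(stats.count(2.5))  # 1

# Unpacking follows the field order: mean, standard error, standard deviation.
mean, sem, sd = stats
```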

class flax_metrics.base.Welford

Welford metric, which computes a running mean and variance using Welford’s algorithm.

This is useful for computing statistics over a stream of data without storing all values in memory.
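A minimal plain-Python sketch of Welford’s recurrence follows. WelfordSketch is a hypothetical name and this is not the library’s code: the library may update from whole JAX arrays at once, while this version folds values in one at a time. It assumes masked-out values are skipped, and it divides the sum of squared deviations by n (the population convention), which is inferred from the Welford example’s output rather than stated in the source.

```python
import math
from typing import Iterable, Optional, Tuple


class WelfordSketch:
    """Hypothetical plain-Python Welford accumulator; not the library's code."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> "WelfordSketch":
        self.n = 0       # number of values folded in so far
        self.mean = 0.0  # running mean
        self.m2 = 0.0    # running sum of squared deviations from the mean
        return self

    def update(self, values: Iterable[float],
               mask: Optional[Iterable[int]] = None) -> "WelfordSketch":
        values = list(values)
        keep_flags = [1] * len(values) if mask is None else list(mask)
        for x, keep in zip(values, keep_flags):
            if not keep:
                continue  # assumption: masked-out values are skipped
            self.n += 1
            delta = x - self.mean
            self.mean += delta / self.n
            # The second factor uses the *updated* mean; this is Welford's recurrence.
            self.m2 += delta * (x - self.mean)
        return self

    def compute(self) -> Tuple[float, float, float]:
        """Return (mean, standard error of the mean, standard deviation)."""
        std = math.sqrt(self.m2 / self.n)  # population variance: divide by n
        sem = std / math.sqrt(self.n)
        return self.mean, sem, std


# Streaming [1, 2] and then [3, 4] reproduces the numbers in the example.
print(WelfordSketch().update([1.0, 2.0]).update([3.0, 4.0]).compute())
```

Only three numbers (n, the mean, and the running sum of squared deviations) are kept regardless of how many values arrive, which is what makes the approach suitable for streams. Updating the mean incrementally also avoids the cancellation error of the naive "sum of squares minus square of sum" formula.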

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import Welford
>>>
>>> values = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> metric = Welford()
>>> metric.update(values=values)
Welford(...)
>>> metric.compute()
Statistics(mean=Array(2.5, dtype=float32),
           standard_error_of_mean=Array(0.559..., dtype=float32),
           standard_deviation=Array(1.118..., dtype=float32))

compute() → Statistics

Compute and return the mean, standard error of the mean, and standard deviation.
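The three returned quantities can be written out as follows for n included values. The divide-by-n (population) convention is inferred from the example output above rather than stated in the source; a sample (n − 1) convention would give a standard deviation of about 1.291 for the same input.

```latex
\begin{aligned}
\mu &= \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad
\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n} (x_i - \mu)^2}, \qquad
\mathrm{SEM} = \frac{\sigma}{\sqrt{n}}. \\[4pt]
x = (1, 2, 3, 4):\quad \mu &= 2.5, \qquad
\sigma^2 = \frac{2.25 + 0.25 + 0.25 + 2.25}{4} = 1.25, \qquad
\sigma \approx 1.1180, \qquad
\mathrm{SEM} = \frac{\sigma}{2} \approx 0.5590.
\end{aligned}
```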

reset() → Self

Reset the metric state.

update(values: Array, *_args, mask: Array | None = None, **_kwargs) → Self

Update the metric with new values.

Parameters:
  • values – Array of values to include in the statistics.

  • mask – Binary mask indicating which elements to include.