Base Metrics
- class flax_metrics.base.Average
Average metric, the arithmetic mean of values.
- Parameters:
argname: Name of the keyword argument to average.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import Average
>>>
>>> values = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> metric = Average()
>>> metric.update(values=values)
Average(...)
>>> metric.compute()
Array(2.5, dtype=float32)
- reset() → Self
Reset the metric state.
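An average over a stream only needs a running total and a count. The pure-Python sketch below is a stand-in, not the flax_metrics implementation. The class name `RunningAverage` and its `total`/`count` fields are hypothetical. It mirrors only the documented behaviour: values are passed under the keyword named by `argname`, and `update()` returns the metric so that calls can be chained.

```python
from typing import Iterable


class RunningAverage:
    """Hypothetical stand-in for an Average-style metric (not the library class)."""

    def __init__(self, argname: str = "values") -> None:
        # Name of the keyword argument whose values are averaged.
        self.argname = argname
        self.reset()

    def reset(self) -> "RunningAverage":
        # Clear the accumulated state in place and return self.
        self.total = 0.0
        self.count = 0
        return self

    def update(self, **kwargs: Iterable[float]) -> "RunningAverage":
        # Accumulate the sum and number of the values passed under `argname`.
        values = list(kwargs[self.argname])
        self.total += sum(values)
        self.count += len(values)
        return self  # Returning self allows update(...).update(...).compute().

    def compute(self) -> float:
        # Arithmetic mean of every value seen since the last reset.
        # This sketch raises ZeroDivisionError if nothing has been seen yet.
        return self.total / self.count


metric = RunningAverage()
print(metric.update(values=[1.0, 2.0]).update(values=[3.0, 4.0]).compute())  # 2.5
```

Splitting the same four values across two `update()` calls gives the same mean as the single call in the example, because only the sum and the count are carried between calls.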
- class flax_metrics.base.BaseMetric
Base class for Flax Metrics implementations.
We inherit from flax.nnx.metrics.Metric to support isinstance type checks. This class overrides update() to accept positional and keyword arguments and a mask parameter. update() also returns Self, so calls to update() and compute() can be chained.
- compute()
Compute and return the value of the Metric.
- reset() → Self
Reset the state of the metric in-place.
- Returns:
The metric instance.
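The base-class contract (positional and keyword arguments, an optional mask, and update() returning Self) can be illustrated with a small pure-Python sketch. Everything here is hypothetical: `SketchMetric`, `MaskedSum`, the `_update` hook, and the assumption that `mask` is a per-element boolean sequence in which False drops an element are illustrative choices, not the documented semantics of flax_metrics.

```python
from typing import Optional, Sequence


class SketchMetric:
    """Hypothetical base class: update() accepts *args, **kwargs and a mask."""

    def update(self, *args, mask: Optional[Sequence[bool]] = None, **kwargs):
        # Assumption: mask is a per-element boolean sequence; False drops an element.
        self._update(*args, mask=mask, **kwargs)
        return self  # Returning self lets update() calls and compute() be chained.

    def _update(self, *args, mask, **kwargs) -> None:
        raise NotImplementedError

    def compute(self):
        raise NotImplementedError

    def reset(self):
        raise NotImplementedError


class MaskedSum(SketchMetric):
    """Hypothetical subclass summing the unmasked elements of one positional argument."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> "MaskedSum":
        # Reset in place and return the instance, as the reset() docs describe.
        self.total = 0.0
        return self

    def _update(self, values, *, mask) -> None:
        # With no mask, every element counts.
        if mask is None:
            mask = [True] * len(values)
        self.total += sum(v for v, keep in zip(values, mask) if keep)

    def compute(self) -> float:
        return self.total


total = MaskedSum().update([1.0, 2.0, 3.0], mask=[True, False, True]).update([4.0]).compute()
print(total)  # 8.0
```

Keeping the chaining in the base class's update() means subclasses only implement the accumulation step and cannot forget to return the instance.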
- class flax_metrics.base.Statistics(mean: Array, standard_error_of_mean: Array, standard_deviation: Array)
Statistics computed by the Welford metric.
- count(value, /)#
Return number of occurrences of value.
- index(value, start=0, stop=9223372036854775807, /)#
Return first index of value.
Raises ValueError if the value is not present.
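The count() and index() entries are the standard tuple methods, which suggests Statistics is a named tuple. Assuming so, its fields can be read by name or unpacked positionally. The sketch below defines a hypothetical `StatisticsSketch` with the same three field names, using plain floats in place of JAX arrays and values rounded from the Welford example below.

```python
from typing import NamedTuple


class StatisticsSketch(NamedTuple):
    """Hypothetical stand-in for Statistics; floats replace JAX arrays."""

    mean: float
    standard_error_of_mean: float
    standard_deviation: float


stats = StatisticsSketch(mean=2.5, standard_error_of_mean=0.559, standard_deviation=1.118)

# Access by field name...
print(stats.mean)  # 2.5
# ...or unpack positionally, in declaration order.
mean, sem, std = stats
# The inherited tuple methods documented above work as usual.
print(stats.index(2.5), stats.count(2.5))  # 0 1
```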
- class flax_metrics.base.Welford
Welford metric, computing the running mean and variance using Welford's algorithm.
This is useful for computing statistics over a stream of data without storing all values in memory.
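Welford's algorithm keeps only a count, the running mean, and the running sum of squared deviations (often called M2), updating them one value at a time. The pure-Python sketch below illustrates the algorithm; it is not the flax_metrics code, and the function name `welford_stats` is hypothetical. Using the population standard deviation, with standard error = std / sqrt(n), is an assumption inferred from the example output below (standard deviation 1.118... and standard error 0.559... for [1, 2, 3, 4]).

```python
import math
from typing import Iterable, Tuple


def welford_stats(values: Iterable[float]) -> Tuple[float, float, float]:
    """Return (mean, standard_error_of_mean, standard_deviation) in one pass.

    Requires at least one value.
    """
    count = 0
    mean = 0.0
    m2 = 0.0  # Running sum of squared deviations from the current mean.
    for x in values:
        count += 1
        delta = x - mean
        mean += delta / count
        # Uses the deviation from both the old and the updated mean.
        m2 += delta * (x - mean)
    # Assumption: population variance (divide by n), matching the documented example.
    std = math.sqrt(m2 / count)
    sem = std / math.sqrt(count)
    return mean, sem, std


print(welford_stats([1.0, 2.0, 3.0, 4.0]))  # approximately (2.5, 0.559, 1.118)
```

Only three numbers are kept no matter how long the stream is, which is what lets the metric avoid storing all values in memory.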
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import Welford
>>>
>>> values = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> metric = Welford()
>>> metric.update(values=values)
Welford(...)
>>> metric.compute()
Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559..., dtype=float32), standard_deviation=Array(1.118..., dtype=float32))
- compute() → Statistics
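Because update() receives a whole array per call, a streaming implementation has to fold each batch into the running state rather than one value at a time. One standard way is the pairwise combination formula of Chan, Golub and LeVeque, sketched below in pure Python. Whether flax_metrics uses this exact formula is an assumption, and `WelfordState`, `batch_state`, and `merge` are hypothetical names. The final lines check that two partial updates reproduce the statistics of the single update in the example.

```python
import math
from typing import NamedTuple, Sequence


class WelfordState(NamedTuple):
    """Hypothetical running state: count, mean and sum of squared deviations (M2)."""

    count: int
    mean: float
    m2: float


def batch_state(values: Sequence[float]) -> WelfordState:
    # Summarise one non-empty batch directly (it is already in memory).
    n = len(values)
    mean = sum(values) / n
    return WelfordState(n, mean, sum((x - mean) ** 2 for x in values))


def merge(a: WelfordState, b: WelfordState) -> WelfordState:
    # Pairwise combination of two partial states (Chan et al.).
    n = a.count + b.count
    if n == 0:
        return a
    delta = b.mean - a.mean
    mean = a.mean + delta * b.count / n
    m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / n
    return WelfordState(n, mean, m2)


state = WelfordState(0, 0.0, 0.0)
for batch in ([1.0, 2.0], [3.0, 4.0]):
    state = merge(state, batch_state(batch))

std = math.sqrt(state.m2 / state.count)  # population standard deviation
sem = std / math.sqrt(state.count)
print(state.mean, sem, std)  # approximately 2.5 0.559 1.118
```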
Compute and return a Statistics instance holding the mean, the standard error of the mean, and the standard deviation.
- reset() → Self
Reset the metric state.
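The Welford example's output is consistent with the following definitions, where $n$ is the number of values seen, $\bar{x}_k$ the running mean after $k$ values, and $M_{2,k}$ the running sum of squared deviations. These are inferred from the documented numbers rather than taken from the library source.

```latex
\begin{aligned}
\bar{x}_k &= \bar{x}_{k-1} + \frac{x_k - \bar{x}_{k-1}}{k}, \\
M_{2,k} &= M_{2,k-1} + (x_k - \bar{x}_{k-1})(x_k - \bar{x}_k), \\
\sigma &= \sqrt{M_{2,n} / n}, \qquad \mathrm{SEM} = \sigma / \sqrt{n}.
\end{aligned}
```

For the values $[1, 2, 3, 4]$ this gives $\bar{x}_4 = 2.5$ and $M_{2,4} = 5$, so $\sigma = \sqrt{5/4} \approx 1.118$ and $\mathrm{SEM} = \sigma / 2 \approx 0.559$, matching the example. A sample-corrected standard deviation ($M_{2,n} / (n - 1)$) would instead give about $1.291$.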