Regression Metrics

class flax_metrics.regression.LpError(p: float, *, norm: bool = True, transform: Callable | None = None)

L-p error metric with optional transform.

Computes mean(|transform(target) - transform(prediction)|^p)^(1/p) if norm=True, or mean(|transform(target) - transform(prediction)|^p) if norm=False.

Parameters:
  • p – Exponent of the L-p metric.

  • norm – If True, raise the mean to the power 1 / p; if False, return the mean of the p-th powers unchanged.

  • transform – Optional function applied element-wise to both target and prediction before the error is evaluated. For example, use jnp.log1p for log-space errors.
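
The formula above can be checked against a few lines of plain jax.numpy. This is a minimal reference sketch, not the library's implementation; the helper name lp_error is hypothetical and it ignores masking and batch accumulation.

```python
import jax.numpy as jnp


def lp_error(target, prediction, p, norm=True, transform=None):
    """Reference L-p error following the formula in the class description."""
    if transform is not None:
        # The transform is applied element-wise to both inputs.
        target, prediction = transform(target), transform(prediction)
    mean_power = jnp.mean(jnp.abs(target - prediction) ** p)
    # norm=True takes the p-th root of the mean; norm=False returns the mean as is.
    return mean_power ** (1.0 / p) if norm else mean_power


target = jnp.array([1.0, 2.0, 3.0])
prediction = jnp.array([1.5, 2.0, 2.5])

mse = lp_error(target, prediction, p=2, norm=False)  # 0.5 / 3
rmse = lp_error(target, prediction, p=2, norm=True)  # sqrt(0.5 / 3)
mae = lp_error(target, prediction, p=1)              # 1.0 / 3
```

With these inputs the three values agree with the outputs shown in the MeanSquaredError, RootMeanSquaredError, and MeanAbsoluteError examples on this page.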

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import LpError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.0, 2.5])
>>> metric = LpError(p=2, norm=False)
>>> metric.update(target=target, prediction=prediction)
LpError(...)
>>> metric.compute()
Array(0.166..., dtype=float32)

compute() → Array

Compute and return the L-p error.

reset() → Self

Reset the metric state.

update(target: Array, prediction: Array, *, mask: Array | None = None, **_) → Self

Update the metric with a batch of predictions.

Parameters:
  • target – Ground truth values.

  • prediction – Predicted values.

  • mask – Binary mask indicating which elements to include.
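
One plausible reading of the mask parameter is a masked mean, where excluded elements count toward neither the sum nor the number of elements. The sketch below illustrates that reading only; the helper masked_mean_power is hypothetical, and the library's exact handling of masks (for example an all-zero mask) is an assumption here.

```python
import jax.numpy as jnp


def masked_mean_power(target, prediction, mask, p):
    """Mean of |target - prediction|^p over elements where mask is nonzero."""
    weights = mask.astype(jnp.float32)
    powered = jnp.abs(target - prediction) ** p
    # Masked-out elements contribute to neither the sum nor the count.
    return jnp.sum(weights * powered) / jnp.sum(weights)


target = jnp.array([1.0, 2.0, 3.0])
prediction = jnp.array([1.5, 2.0, 2.5])
mask = jnp.array([1, 1, 0])  # drop the last element, e.g. padding

masked_mse = masked_mean_power(target, prediction, mask, p=2)  # (0.25 + 0.0) / 2
```

In this sketch an all-zero mask divides by zero and yields NaN, so a caller would need to guard against empty selections.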

class flax_metrics.regression.MeanAbsoluteError

Mean absolute error (MAE): the mean of |target - prediction|.

See also

This metric is implemented in scikit-learn as sklearn.metrics.mean_absolute_error().

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import MeanAbsoluteError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.0, 2.5])
>>> metric = MeanAbsoluteError()
>>> metric.update(target=target, prediction=prediction)
MeanAbsoluteError(...)
>>> metric.compute()
Array(0.333..., dtype=float32)

compute() → Array

Compute and return the L-p error.

reset() → Self

Reset the metric state.

update(target: Array, prediction: Array, *, mask: Array | None = None, **_) → Self

Update the metric with a batch of predictions.

Parameters:
  • target – Ground truth values.

  • prediction – Predicted values.

  • mask – Binary mask indicating which elements to include.
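
The update / compute / reset cycle suggests that a metric accumulates state across batches and reports a single value over everything seen so far. The sketch below shows one way such an accumulator could work for the mean absolute error, assuming a running sum and element count; RunningAbsoluteError is a hypothetical class and not part of flax_metrics.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class RunningAbsoluteError:
    """Hypothetical accumulator; not part of flax_metrics."""

    total: float = 0.0
    count: int = 0

    def update(self, target, prediction):
        # Accumulate the sum of absolute errors and the number of elements.
        self.total += float(jnp.sum(jnp.abs(target - prediction)))
        self.count += int(target.size)
        return self  # returning self allows chained calls

    def compute(self):
        return self.total / self.count

    def reset(self):
        self.total, self.count = 0.0, 0
        return self


metric = RunningAbsoluteError()
metric.update(jnp.array([1.0, 2.0]), jnp.array([1.5, 2.0]))
metric.update(jnp.array([3.0]), jnp.array([2.5]))
streamed = metric.compute()  # (0.5 + 0.0 + 0.5) / 3
```

Under this assumption, splitting the data into batches gives the same result as a single update with all three elements, because only sums and counts are stored.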

class flax_metrics.regression.MeanSquaredError

Mean squared error (MSE): the mean of (target - prediction)^2.

See also

This metric is implemented in scikit-learn as sklearn.metrics.mean_squared_error().

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import MeanSquaredError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.0, 2.5])
>>> metric = MeanSquaredError()
>>> metric.update(target=target, prediction=prediction)
MeanSquaredError(...)
>>> metric.compute()
Array(0.166..., dtype=float32)

compute() → Array

Compute and return the L-p error.

reset() → Self

Reset the metric state.

update(target: Array, prediction: Array, *, mask: Array | None = None, **_) → Self

Update the metric with a batch of predictions.

Parameters:
  • target – Ground truth values.

  • prediction – Predicted values.

  • mask – Binary mask indicating which elements to include.

class flax_metrics.regression.MeanSquaredLogError

Mean squared logarithmic error (MSLE): the mean of (log1p(target) - log1p(prediction))^2.

See also

This metric is implemented in scikit-learn as sklearn.metrics.mean_squared_log_error().
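
As a cross-check, the value can be reproduced by hand with jnp.log1p, the transform suggested for log-space errors in the LpError description. This is a sketch under the assumption that the metric is the squared error of log1p-transformed inputs, which is also how scikit-learn defines MSLE; it does not call the library.

```python
import jax.numpy as jnp

target = jnp.array([1.0, 2.0, 3.0])
prediction = jnp.array([1.5, 2.5, 3.5])

# log1p(x) = log(1 + x), so zero-valued inputs map to 0 rather than -inf.
log_diff = jnp.log1p(target) - jnp.log1p(prediction)
msle = jnp.mean(log_diff ** 2)  # mean of squared log-space differences
```

Because log1p is only defined for inputs greater than -1, this formulation is suited to non-negative targets and predictions.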

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import MeanSquaredLogError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.5, 3.5])
>>> metric = MeanSquaredLogError()
>>> metric.update(target=target, prediction=prediction)
MeanSquaredLogError(...)
>>> metric.compute()
Array(0.0291..., dtype=float32)

compute() → Array

Compute and return the L-p error.

reset() → Self

Reset the metric state.

update(target: Array, prediction: Array, *, mask: Array | None = None, **_) → Self

Update the metric with a batch of predictions.

Parameters:
  • target – Ground truth values.

  • prediction – Predicted values.

  • mask – Binary mask indicating which elements to include.

class flax_metrics.regression.RootMeanSquaredError

Root mean squared error (RMSE): the square root of the mean of (target - prediction)^2.

See also

This metric is implemented in scikit-learn as sklearn.metrics.root_mean_squared_error().

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import RootMeanSquaredError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.0, 2.5])
>>> metric = RootMeanSquaredError()
>>> metric.update(target=target, prediction=prediction)
RootMeanSquaredError(...)
>>> metric.compute()
Array(0.408..., dtype=float32)

compute() → Array

Compute and return the L-p error.

reset() → Self

Reset the metric state.

update(target: Array, prediction: Array, *, mask: Array | None = None, **_) → Self

Update the metric with a batch of predictions.

Parameters:
  • target – Ground truth values.

  • prediction – Predicted values.

  • mask – Binary mask indicating which elements to include.

class flax_metrics.regression.RootMeanSquaredLogError

Root mean squared logarithmic error (RMSLE): the square root of the mean of (log1p(target) - log1p(prediction))^2.
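
A practical consequence of working in log space is that the error depends on the ratio of (1 + prediction) to (1 + target), so under-prediction is penalized more than over-prediction by the same absolute amount. The sketch below illustrates this with a hypothetical helper rmsle written in plain jax.numpy, assuming the log1p formulation; it is not the library's code.

```python
import jax.numpy as jnp


def rmsle(target, prediction):
    """Root of the mean squared difference of log1p-transformed values."""
    log_diff = jnp.log1p(target) - jnp.log1p(prediction)
    return jnp.sqrt(jnp.mean(log_diff ** 2))


target = jnp.array([10.0])
under = rmsle(target, jnp.array([5.0]))   # |log(11 / 6)|
over = rmsle(target, jnp.array([15.0]))   # |log(16 / 11)|
```

Both predictions miss the target by 5, yet the under-prediction scores roughly 0.61 against roughly 0.37 for the over-prediction.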

Example

>>> from jax import numpy as jnp
>>> from flax_metrics import RootMeanSquaredLogError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.5, 3.5])
>>> metric = RootMeanSquaredLogError()
>>> metric.update(target=target, prediction=prediction)
RootMeanSquaredLogError(...)
>>> metric.compute()
Array(0.170..., dtype=float32)

compute() → Array

Compute and return the L-p error.

reset() → Self

Reset the metric state.

update(target: Array, prediction: Array, *, mask: Array | None = None, **_) → Self

Update the metric with a batch of predictions.

Parameters:
  • target – Ground truth values.

  • prediction – Predicted values.

  • mask – Binary mask indicating which elements to include.