Regression Metrics
- class flax_metrics.regression.LpError(p: float, *, norm: bool = True, transform: Callable | None = None)
L-p error metric with optional transform.
Computes mean(|transform(target) - transform(prediction)|^p)^(1/p) if norm=True, or mean(|transform(target) - transform(prediction)|^p) if norm=False.
- Parameters:
  - p – Exponent of the L-p metric.
  - norm – Normalize the metric by raising it to the power 1/p.
  - transform – Transformation applied element-wise to the inputs before evaluating the metric. For example, use jnp.log1p for log-space errors.
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import LpError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.0, 2.5])
>>> metric = LpError(p=2, norm=False)
>>> metric.update(target=target, prediction=prediction)
LpError(...)
>>> metric.compute()
Array(0.166..., dtype=float32)
- reset() -> Self
Reset the metric state.
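The LpError formula can be written as a stateless function for a single batch. This is a minimal sketch in plain jax.numpy, not the library's implementation: the function name lp_error is hypothetical, and unlike the metric it does not accumulate state across update() calls.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def lp_error(
    target: jax.Array,
    prediction: jax.Array,
    p: float,
    norm: bool = True,
    transform: Optional[Callable[[jax.Array], jax.Array]] = None,
) -> jax.Array:
    """Hypothetical single-batch version of the documented L-p error formula."""
    if transform is not None:
        # Apply the element-wise transform to both inputs, e.g. jnp.log1p.
        target = transform(target)
        prediction = transform(prediction)
    # mean(|transform(target) - transform(prediction)|^p)
    value = jnp.mean(jnp.abs(target - prediction) ** p)
    # Optionally normalize by raising the mean to the power 1/p.
    return value ** (1.0 / p) if norm else value


target = jnp.array([1.0, 2.0, 3.0])
prediction = jnp.array([1.5, 2.0, 2.5])
print(lp_error(target, prediction, p=2, norm=False))  # same setting as the doctest
print(lp_error(target, prediction, p=2, norm=True))   # square root of the value above
```

With p=2 and norm=False this reproduces the doctest value of about 0.167; with norm=True it gives about 0.408, the root mean squared error of the same inputs.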
- class flax_metrics.regression.MeanAbsoluteError
Mean absolute error (MAE).
See also
This metric is implemented in scikit-learn as
sklearn.metrics.mean_absolute_error().
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import MeanAbsoluteError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.0, 2.5])
>>> metric = MeanAbsoluteError()
>>> metric.update(target=target, prediction=prediction)
MeanAbsoluteError(...)
>>> metric.compute()
Array(0.333..., dtype=float32)
- reset() -> Self
Reset the metric state.
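Since these metrics are fed with update() and read with compute(), an MAE over several batches has to accumulate totals rather than average per-batch results. The class below is a hypothetical sketch of that update/compute/reset pattern, assuming the state is a running sum of absolute errors plus an element count; it is not the flax_metrics implementation.

```python
import jax.numpy as jnp


class RunningMAE:
    """Hypothetical streaming MAE: accumulates a sum of |errors| and a count."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> "RunningMAE":
        self.total = jnp.zeros(())  # running sum of absolute errors
        self.count = jnp.zeros(())  # running number of elements seen
        return self

    def update(self, *, target, prediction) -> "RunningMAE":
        errors = jnp.abs(target - prediction)
        self.total = self.total + errors.sum()
        self.count = self.count + errors.size
        return self  # returning self allows chained calls

    def compute(self):
        return self.total / self.count


metric = RunningMAE()
# Split the doctest inputs into two batches of unequal size.
metric.update(target=jnp.array([1.0]), prediction=jnp.array([1.5]))
metric.update(target=jnp.array([2.0, 3.0]), prediction=jnp.array([2.0, 2.5]))
print(metric.compute())  # same value as a single update with all three elements
```

Averaging the two per-batch MAEs (0.5 and 0.25) would instead give 0.375 rather than 1/3, which is why the sketch keeps totals and divides only in compute().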
- class flax_metrics.regression.MeanSquaredError
Mean squared error (MSE).
See also
This metric is implemented in scikit-learn as
sklearn.metrics.mean_squared_error().
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import MeanSquaredError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.0, 2.5])
>>> metric = MeanSquaredError()
>>> metric.update(target=target, prediction=prediction)
MeanSquaredError(...)
>>> metric.compute()
Array(0.166..., dtype=float32)
- reset() -> Self
Reset the metric state.
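Squaring the residuals makes MSE weight large errors more heavily than MAE does. The sketch below uses two hypothetical helper functions, mae and mse, written directly in jax.numpy rather than the library classes, to compare two sets of predictions with the same total absolute error.

```python
import jax.numpy as jnp


def mae(target, prediction):
    # Mean of absolute residuals.
    return jnp.mean(jnp.abs(target - prediction))


def mse(target, prediction):
    # Mean of squared residuals.
    return jnp.mean(jnp.square(target - prediction))


target = jnp.zeros(4)
spread = jnp.array([1.0, 1.0, 1.0, 1.0])   # four errors of size 1
outlier = jnp.array([0.0, 0.0, 0.0, 4.0])  # one error of size 4

print(mae(target, spread), mae(target, outlier))  # equal MAE (1.0 each)
print(mse(target, spread), mse(target, outlier))  # MSE is 4x larger with the outlier
```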
- class flax_metrics.regression.MeanSquaredLogError
Mean squared logarithmic error (MSLE).
See also
This metric is implemented in scikit-learn as
sklearn.metrics.mean_squared_log_error().
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import MeanSquaredLogError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.5, 3.5])
>>> metric = MeanSquaredLogError()
>>> metric.update(target=target, prediction=prediction)
MeanSquaredLogError(...)
>>> metric.compute()
Array(0.0291..., dtype=float32)
- reset() -> Self
Reset the metric state.
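MSLE compares log1p(target) with log1p(prediction), that is log(1 + x). A minimal jax.numpy sketch with a hypothetical helper msle reproduces the doctest value. Because log1p(x) is only finite for x > -1, the sketch assumes non-negative inputs and does not validate them.

```python
import jax.numpy as jnp


def msle(target, prediction):
    # Assumes non-negative inputs: log1p(x) = log(1 + x) diverges as x -> -1.
    log_diff = jnp.log1p(target) - jnp.log1p(prediction)
    return jnp.mean(jnp.square(log_diff))


target = jnp.array([1.0, 2.0, 3.0])
prediction = jnp.array([1.5, 2.5, 3.5])
print(msle(target, prediction))  # matches the MeanSquaredLogError doctest (about 0.0291)
```

By the formula documented for LpError, the same quantity corresponds to LpError(p=2, norm=False, transform=jnp.log1p).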
- class flax_metrics.regression.RootMeanSquaredError
Root mean squared error (RMSE).
See also
This metric is implemented in scikit-learn as
sklearn.metrics.root_mean_squared_error().
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import RootMeanSquaredError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.0, 2.5])
>>> metric = RootMeanSquaredError()
>>> metric.update(target=target, prediction=prediction)
RootMeanSquaredError(...)
>>> metric.compute()
Array(0.408..., dtype=float32)
- reset() -> Self
Reset the metric state.
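The square root in RMSE is applied after averaging, so a streaming RMSE cannot simply average per-batch RMSE values, even when batches have equal size. The sketch below uses plain jax.numpy and made-up data (not the library's internals) to compare the two.

```python
import jax.numpy as jnp

# Two equally sized batches of (target, prediction); made-up data.
batches = [
    (jnp.array([1.0, 2.0]), jnp.array([1.5, 2.0])),
    (jnp.array([3.0, 4.0]), jnp.array([2.5, 5.0])),
]

# Per-batch MSE and RMSE.
batch_mse = [jnp.mean(jnp.square(t - p)) for t, p in batches]
batch_rmse = [jnp.sqrt(m) for m in batch_mse]

# With equal batch sizes, the mean of per-batch MSEs equals the overall MSE,
# so the overall RMSE is the square root of that mean.
rmse = jnp.sqrt(jnp.mean(jnp.stack(batch_mse)))
# Averaging per-batch RMSEs gives a smaller value, since sqrt is concave.
naive = jnp.mean(jnp.stack(batch_rmse))

print(rmse, naive)  # about 0.612 versus about 0.572
```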
- class flax_metrics.regression.RootMeanSquaredLogError
Root mean squared logarithmic error (RMSLE).
Example
>>> from jax import numpy as jnp
>>> from flax_metrics import RootMeanSquaredLogError
>>>
>>> target = jnp.array([1.0, 2.0, 3.0])
>>> prediction = jnp.array([1.5, 2.5, 3.5])
>>> metric = RootMeanSquaredLogError()
>>> metric.update(target=target, prediction=prediction)
RootMeanSquaredLogError(...)
>>> metric.compute()
Array(0.170..., dtype=float32)
- reset() -> Self
Reset the metric state.
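Because RMSLE works on log(1 + x), for values well above 1 the log difference is close to log(target / prediction), so errors of the same relative size score similarly regardless of scale. The sketch uses a hypothetical rmsle helper in plain jax.numpy, not the library class.

```python
import jax.numpy as jnp


def rmsle(target, prediction):
    # Square root of the mean squared difference of log1p-transformed values.
    return jnp.sqrt(jnp.mean(jnp.square(jnp.log1p(target) - jnp.log1p(prediction))))


# The doctest inputs: RMSLE is the square root of the MSLE value (about 0.0291).
print(rmsle(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.5, 2.5, 3.5])))

# A 10% over-prediction at two very different scales.
small = rmsle(jnp.array([100.0]), jnp.array([110.0]))
large = rmsle(jnp.array([1000.0]), jnp.array([1100.0]))
print(small, large)  # both close to log(1.1), roughly 0.094 and 0.095
```

The absolute errors here differ by a factor of 10 (10 versus 100), yet the two RMSLE values are nearly equal.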