# Source code for cyclic_boosting.link

"""This module contains some general/canonical link-mean-function pairs such as

- :class:`~LogLinkMixin`
- :class:`~LogitLinkMixin`
- :class:`~IdentityLinkMixin`
"""

from __future__ import absolute_import, division, print_function

import abc

import numexpr
import numpy as np
import six


class LinkFunction(metaclass=abc.ABCMeta):
    r"""Abstract base class for link function computations.

    Subclasses must implement :meth:`is_in_range`; the base class itself
    cannot be instantiated because that method is abstract.
    """

    # The native ``metaclass=`` keyword replaces the former
    # ``six.add_metaclass`` shim: the annotated signatures in this module
    # already require Python 3, where both spellings build the same class.
    @abc.abstractmethod
    def is_in_range(self, values: np.ndarray) -> bool:
        """Check if values can be transformed by the link function.

        Parameters
        ----------
        values
            Values to be checked against the domain of the link function.

        Returns
        -------
        bool
            Truth value telling whether ``values`` lie in the supported
            range. The base implementation has no body and returns
            ``None``; concrete subclasses supply the actual check.
        """
class LogLinkMixin(LinkFunction):
    r"""Link function and mean function for example for Poisson-distributed
    data.

    Supported values are in the range :math:`x > 0`
    """

    def is_in_range(self, m: np.ndarray) -> bool:
        """Tell whether every entry of ``m`` is strictly positive.

        Parameters
        ----------
        m
            Values to be checked.

        Returns
        -------
        bool
            Result of ``np.all`` over the element-wise test ``m > 0.0``.
        """
        strictly_positive = m > 0.0
        return np.all(strictly_positive)
class LogitLinkMixin(LinkFunction):
    r"""Link for the logit transformation.

    Supported values are in the range :math:`0 \leq x \leq 1`
    """

    def is_in_range(self, p: np.ndarray) -> bool:
        """Tell whether every entry of ``p`` lies in the closed interval [0, 1].

        Parameters
        ----------
        p
            Values to be checked.

        Returns
        -------
        bool
            Result of ``np.all`` over the element-wise interval test.
        """
        # numexpr looks up ``p`` by name in this frame, so the parameter
        # name and the expression string have to stay in sync.
        inside_unit_interval = numexpr.evaluate("(p >= 0.0) & (p <= 1.0)")
        return np.all(inside_unit_interval)
class IdentityLinkMixin(LinkFunction):
    """Identity link"""

    def is_in_range(self, m: np.ndarray):
        """Accept any input: this check returns ``True`` unconditionally.

        Parameters
        ----------
        m
            Values to be checked; not inspected by this implementation.

        Returns
        -------
        bool
            Always ``True``.
        """
        return True
# Public API of this module: the abstract base class followed by the
# concrete link mixins.
__all__ = [
    "LinkFunction",
    "LogLinkMixin",
    "LogitLinkMixin",
    "IdentityLinkMixin",
]