Program Listing for File BF_Py.py

Program Listing for File BF_Py.py

Return to documentation for file (src/tfc/utils/BF_Py.py)

import numpy as np
import jax.numpy as jnp
from abc import ABC, abstractmethod
from tfc.utils.tfc_types import uint, Number, JaxOrNumpyArray
from typing import Callable, Tuple


class BasisFunc(ABC):
    """
    Python implementation of the basis function classes. These are an alternative
    to the C++ versions. They cannot be JIT-ed, but they do support alternative
    types, e.g., single float, complex, etc. Even though they cannot be JIT-ed,
    they can often be used in JIT functions, if their arguments can be removed
    from said functions. For example, when solving an ODE, oftentimes the basis
    functions can be treated as compile time constants. This can be done using
    `pejit`: see `pejit` for more details.
    """

    def __init__(
        self,
        x0: Number,
        xf: Number,
        nC: JaxOrNumpyArray,
        m: uint,
        z0: Number = 0,
        zf: Number = float("inf"),
    ) -> None:
        """
        Initialize the basis class.

        Parameters
        ----------
        x0 : Number
            Start of the problem domain.
        xf : Number
            End of the problem domain.
        nC : JaxOrNumpyArray
            Basis functions to be removed (column indices that `H` deletes
            unless ``full=True``).
        m : uint
            Number of basis functions.
        z0 : Number
            Start of the basis function domain.
        zf : Number
            End of the basis function domain. When infinite (the default), the
            problem domain is not rescaled: the map uses c = 1 and x0 = 0.
        """

        self._m = m
        # The `_Hint` implementations in this module read the basis-function
        # count as `self._m_m`, the attribute name `nBasisFunc` stores. Keep an
        # alias so 1-D subclasses built through this initializer can evaluate
        # `H` instead of failing with AttributeError.
        self._m_m = m
        self._nC = nC
        self._numC = len(nC)

        self._z0 = z0
        if zf == float("inf"):
            # Unbounded basis domain: no rescaling, z = x + z0.
            self._c = 1.0
            self._x0 = 0.0
        else:
            self._x0 = x0
            self._c = (zf - z0) / (xf - x0)

    def H(self, x: JaxOrNumpyArray, d: uint = 0, full: bool = False) -> JaxOrNumpyArray:
        """
        Returns the basis function matrix for the x with a derivative of order d.

        Parameters
        ----------
        x : NDArray
            Input array. Values to calculate the basis function for.
        d : uint
            Order of the derivative
        full : bool
            Whether to return the full basis function set, or remove
            the columns associated with self._nC.

        Returns
        -------
        H : NDArray
            The basis function values.
        """

        # Map the problem domain onto the basis function domain.
        z = (x - self._x0) * self._c + self._z0
        if len(z.shape) == 1:
            # The `_Hint` implementations build their output with np.hstack
            # against z, so they need a column vector of shape (N, 1).
            z = np.expand_dims(z, 1)
        # Chain rule: every derivative with respect to x contributes a factor c.
        dMult = self._c**d
        F = self._Hint(z, d) * dMult
        if not full and self._numC > 0:
            F = np.delete(F, self._nC, axis=1)
        return F

    @abstractmethod
    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Internal method used to calculate the basis function value.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for.
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Basis function values.
        """
        pass

    @property
    def c(self) -> Number:
        """
        Return the constant that maps the problem domain to the basis
        function domain.

        Returns
        -------
        float
            The constant that maps the problem domain to the basis function
            domain.
        """

        return self._c


class CP(BasisFunc):
    """
    Chebyshev polynomial basis functions on the basis domain [-1, 1].
    """

    def __init__(
        self,
        x0: Number,
        xf: Number,
        nC: JaxOrNumpyArray,
        m: uint,
    ) -> None:
        """
        Build a Chebyshev basis whose problem domain [x0, xf] is mapped to [-1, 1].

        Parameters
        ----------
        x0 : Number
            Start of the problem domain.
        xf : Number
            End of the problem domain.
        nC : JaxOrNumpyArray
            Basis functions to be removed.
        m : uint
            Number of basis functions.
        """
        super().__init__(x0, xf, nC, m, z0=-1.0, zf=1.0)

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Evaluate T_0 ... T_{m-1}, or their d-th derivative, at the points z.

        Parameters
        ----------
        z : NDArray
            Points in the basis domain, as a column vector of shape (N, 1).
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Array of shape (N, m); column k holds the d-th derivative of T_k.
        """
        nPts = np.size(z)
        nBasis = self._m_m
        ones = np.ones_like(z)
        zeros = np.zeros_like(z)

        # Small bases: write the answers down directly.
        if nBasis == 1:
            return zeros if d > 0 else ones
        if nBasis == 2:
            if d > 1:
                return np.hstack((zeros, zeros))
            if d > 0:
                return np.hstack((zeros, ones))
            return np.hstack((ones, z))

        # Values from T_0 = 1, T_1 = z, T_k = 2 z T_{k-1} - T_{k-2}.
        cur = np.hstack((ones, z, np.zeros((nPts, nBasis - 2), dtype=z.dtype)))
        for k in range(2, nBasis):
            cur[:, k : k + 1] = 2 * z * cur[:, k - 1 : k] - cur[:, k - 2 : k - 1]

        # Each pass differentiates the recurrence once:
        # T_k^(j+1) = 2 (j+1) T_{k-1}^(j) + 2 z T_{k-1}^(j+1) - T_{k-2}^(j+1).
        order = 0
        while order < d:
            if order == 0:
                nxt = np.hstack((zeros, ones, np.zeros((nPts, nBasis - 2), dtype=z.dtype)))
            else:
                nxt = np.zeros((nPts, nBasis), dtype=z.dtype)
            for k in range(2, nBasis):
                nxt[:, k : k + 1] = (
                    (2 + 2 * order) * cur[:, k - 1 : k]
                    + 2 * z * nxt[:, k - 1 : k]
                    - nxt[:, k - 2 : k - 1]
                )
            cur = nxt
            order += 1
        return cur


class LeP(BasisFunc):
    """
    Legendre polynomial basis functions on the basis domain [-1, 1].
    """

    def __init__(
        self,
        x0: Number,
        xf: Number,
        nC: JaxOrNumpyArray,
        m: uint,
    ) -> None:
        """
        Build a Legendre basis whose problem domain [x0, xf] is mapped to [-1, 1].

        Parameters
        ----------
        x0 : Number
            Start of the problem domain.
        xf : Number
            End of the problem domain.
        nC : JaxOrNumpyArray
            Basis functions to be removed.
        m : uint
            Number of basis functions.
        """
        super().__init__(x0, xf, nC, m, z0=-1.0, zf=1.0)

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Evaluate P_0 ... P_{m-1}, or their d-th derivative, at the points z.

        Parameters
        ----------
        z : NDArray
            Points in the basis domain, as a column vector of shape (N, 1).
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Array of shape (N, m); column k holds the d-th derivative of P_k.
        """
        nPts = np.size(z)
        nBasis = self._m_m
        ones = np.ones_like(z)
        zeros = np.zeros_like(z)

        # Small bases: write the answers down directly.
        if nBasis == 1:
            return zeros if d > 0 else ones
        if nBasis == 2:
            if d > 1:
                return np.hstack((zeros, zeros))
            if d > 0:
                return np.hstack((zeros, ones))
            return np.hstack((ones, z))

        # Values from (k + 1) P_{k+1} = (2k + 1) z P_k - k P_{k-1}.
        cur = np.hstack((ones, z, np.zeros((nPts, nBasis - 2), dtype=z.dtype)))
        for k in range(1, nBasis - 1):
            cur[:, k + 1 : k + 2] = (
                (2.0 * k + 1.0) * z * cur[:, k : k + 1] - k * cur[:, k - 1 : k]
            ) / (k + 1.0)

        # Each pass differentiates the recurrence once:
        # (k+1) P_{k+1}^(j+1) = (2k+1) ((j+1) P_k^(j) + z P_k^(j+1)) - k P_{k-1}^(j+1).
        order = 0
        while order < d:
            if order == 0:
                nxt = np.hstack((zeros, ones, np.zeros((nPts, nBasis - 2), dtype=z.dtype)))
            else:
                nxt = np.zeros((nPts, nBasis), dtype=z.dtype)
            for k in range(1, nBasis - 1):
                nxt[:, k + 1 : k + 2] = (
                    (2.0 * k + 1.0)
                    * ((order + 1.0) * cur[:, k : k + 1] + z * nxt[:, k : k + 1])
                    - k * nxt[:, k - 1 : k]
                ) / (k + 1.0)
            cur = nxt
            order += 1
        return cur


class LaP(BasisFunc):
    """
    Laguerre polynomial basis functions.
    """

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Evaluate L_0 ... L_{m-1}, or their d-th derivative, at the points z.

        Parameters
        ----------
        z : NDArray
            Points in the basis domain, as a column vector of shape (N, 1).
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Array of shape (N, m); column k holds the d-th derivative of L_k.
        """
        nPts = np.size(z)
        nBasis = self._m_m
        ones = np.ones_like(z)
        zeros = np.zeros_like(z)

        # Small bases: write the answers down directly.
        if nBasis == 1:
            return zeros if d > 0 else ones
        if nBasis == 2:
            if d > 1:
                return np.hstack((zeros, zeros))
            if d > 0:
                return np.hstack((zeros, -ones))
            return np.hstack((ones, 1.0 - z))

        # Values from (k + 1) L_{k+1} = (2k + 1 - z) L_k - k L_{k-1}.
        cur = np.hstack((ones, 1.0 - z, np.zeros((nPts, nBasis - 2), dtype=z.dtype)))
        for k in range(1, nBasis - 1):
            cur[:, k + 1 : k + 2] = (
                (2.0 * k + 1.0 - z) * cur[:, k : k + 1] - k * cur[:, k - 1 : k]
            ) / (k + 1.0)

        # Each pass differentiates the recurrence once:
        # (k+1) L_{k+1}^(j+1) = (2k+1-z) L_k^(j+1) - (j+1) L_k^(j) - k L_{k-1}^(j+1).
        order = 0
        while order < d:
            if order == 0:
                nxt = np.hstack((zeros, -ones, np.zeros((nPts, nBasis - 2), dtype=z.dtype)))
            else:
                nxt = np.zeros((nPts, nBasis), dtype=z.dtype)
            for k in range(1, nBasis - 1):
                nxt[:, k + 1 : k + 2] = (
                    (2.0 * k + 1.0 - z) * nxt[:, k : k + 1]
                    - (order + 1.0) * cur[:, k : k + 1]
                    - k * nxt[:, k - 1 : k]
                ) / (k + 1.0)
            cur = nxt
            order += 1
        return cur


class HoPpro(BasisFunc):
    """
    Hermite polynomial basis functions (probabilist's convention).
    """

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Evaluate He_0 ... He_{m-1}, or their d-th derivative, at the points z.

        Parameters
        ----------
        z : NDArray
            Points in the basis domain, as a column vector of shape (N, 1).
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Array of shape (N, m); column k holds the d-th derivative of He_k.
        """
        nPts = np.size(z)
        nBasis = self._m_m
        ones = np.ones_like(z)
        zeros = np.zeros_like(z)

        # Small bases: write the answers down directly.
        if nBasis == 1:
            return zeros if d > 0 else ones
        if nBasis == 2:
            if d > 1:
                return np.hstack((zeros, zeros))
            if d > 0:
                return np.hstack((zeros, ones))
            return np.hstack((ones, z))

        # Values from He_{k+1} = z He_k - k He_{k-1}.
        cur = np.hstack((ones, z, np.zeros((nPts, nBasis - 2), dtype=z.dtype)))
        for k in range(1, nBasis - 1):
            cur[:, k + 1 : k + 2] = z * cur[:, k : k + 1] - k * cur[:, k - 1 : k]

        # Each pass differentiates the recurrence once:
        # He_{k+1}^(j+1) = (j+1) He_k^(j) + z He_k^(j+1) - k He_{k-1}^(j+1).
        order = 0
        while order < d:
            if order == 0:
                nxt = np.hstack((zeros, ones, np.zeros((nPts, nBasis - 2), dtype=z.dtype)))
            else:
                nxt = np.zeros((nPts, nBasis), dtype=z.dtype)
            for k in range(1, nBasis - 1):
                nxt[:, k + 1 : k + 2] = (
                    (order + 1.0) * cur[:, k : k + 1]
                    + z * nxt[:, k : k + 1]
                    - k * nxt[:, k - 1 : k]
                )
            cur = nxt
            order += 1
        return cur


class HoPphy(BasisFunc):
    """
    Hermite polynomial basis functions (physicist's convention).
    """

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Evaluate H_0 ... H_{m-1}, or their d-th derivative, at the points z.

        Parameters
        ----------
        z : NDArray
            Points in the basis domain, as a column vector of shape (N, 1).
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Array of shape (N, m); column k holds the d-th derivative of H_k.
        """
        nPts = np.size(z)
        nBasis = self._m_m
        ones = np.ones_like(z)
        zeros = np.zeros_like(z)

        # Small bases: write the answers down directly.
        if nBasis == 1:
            return zeros if d > 0 else ones
        if nBasis == 2:
            if d > 1:
                return np.hstack((zeros, zeros))
            if d > 0:
                return np.hstack((zeros, 2.0 * ones))
            return np.hstack((ones, 2.0 * z))

        # Values from H_{k+1} = 2 z H_k - 2 k H_{k-1}.
        cur = np.hstack((ones, 2.0 * z, np.zeros((nPts, nBasis - 2), dtype=z.dtype)))
        for k in range(1, nBasis - 1):
            cur[:, k + 1 : k + 2] = 2.0 * z * cur[:, k : k + 1] - 2.0 * k * cur[:, k - 1 : k]

        # Each pass differentiates the recurrence once:
        # H_{k+1}^(j+1) = 2 (j+1) H_k^(j) + 2 z H_k^(j+1) - 2 k H_{k-1}^(j+1).
        order = 0
        while order < d:
            if order == 0:
                nxt = np.hstack(
                    (zeros, 2.0 * ones, np.zeros((nPts, nBasis - 2), dtype=z.dtype))
                )
            else:
                nxt = np.zeros((nPts, nBasis), dtype=z.dtype)
            for k in range(1, nBasis - 1):
                nxt[:, k + 1 : k + 2] = (
                    2.0 * (order + 1.0) * cur[:, k : k + 1]
                    + 2.0 * z * nxt[:, k : k + 1]
                    - 2.0 * k * nxt[:, k - 1 : k]
                )
            cur = nxt
            order += 1
        return cur


class FS(BasisFunc):
    """
    Fourier series basis functions on the basis domain [-pi, pi].

    Column 0 is the constant term. For k >= 1, odd columns hold sin(g z) and
    even columns hold cos(g z), with frequency g = ceil(k / 2).
    """

    def __init__(
        self,
        x0: Number,
        xf: Number,
        nC: JaxOrNumpyArray,
        m: uint,
    ) -> None:
        """
        Initialize the basis class.

        Parameters
        ----------
        x0 : Number
            Start of the problem domain.
        xf : Number
            End of the problem domain.
        nC : JaxOrNumpyArray
            Basis functions to be removed
        m : uint
            Number of basis functions.
        """
        super().__init__(x0, xf, nC, m, -np.pi, np.pi)

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Internal method used to calculate the FS basis function values.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for, as a column vector
            of shape (N, 1).
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Basis function values, shape (N, m).
        """
        N = np.size(z)
        # Keep the input's floating/complex precision (as the polynomial
        # classes do with dtype=z.dtype) instead of always forcing float64,
        # which would discard the imaginary part of complex inputs.
        dtype = z.dtype if np.issubdtype(z.dtype, np.inexact) else np.float64
        F = np.zeros((N, self._m_m), dtype=dtype)
        if d == 0:
            F[:, 0] = 1.0
        # For d > 0 the constant column's derivative is zero and stays zero.

        # The d-th derivatives of cos/sin cycle with period 4. For each phase
        # d % 4, give the (function, sign) used for cosine (even k) and sine
        # (odd k) columns; the magnitude is always g**d.
        phase = int(d) % 4
        cosTerm = ((np.cos, 1.0), (np.sin, -1.0), (np.cos, -1.0), (np.sin, 1.0))[phase]
        sinTerm = ((np.sin, 1.0), (np.cos, 1.0), (np.sin, -1.0), (np.cos, -1.0))[phase]
        for k in range(1, self._m_m):
            g = np.ceil(k / 2.0)
            func, sign = cosTerm if k % 2 == 0 else sinTerm
            F[:, k : k + 1] = sign * g**d * func(g * z)
        return F


class ELM(BasisFunc):
    """
    Abstract base class for extreme learning machine (ELM) basis functions.

    Holds random weights ``w`` and biases ``b``, each stored as a row of
    shape (1, m). Concrete subclasses evaluate an activation of w * z + b.
    """

    def __init__(
        self,
        x0: Number,
        xf: Number,
        nC: JaxOrNumpyArray,
        m: uint,
    ) -> None:
        """
        Map the problem domain onto [0, 1] and draw the random parameters.

        Parameters
        ----------
        x0 : Number
            Start of the problem domain.
        xf : Number
            End of the problem domain.
        nC : JaxOrNumpyArray
            Basis functions to be removed
        m : uint
            Number of basis functions.
        """
        super().__init__(x0, xf, nC, m, 0.0, 1.0)

        # Multiplying by a one-element array casts the draws to the dtype
        # of the domain-mapping constant.
        caster = np.ones(1, dtype=np.array(self._c).dtype)
        rowShape = (1, self._m)
        # Weights are drawn before biases, uniformly from [-10, 10).
        self._w = (np.random.uniform(low=-10.0, high=10.0, size=self._m) * caster).reshape(
            rowShape
        )
        self._b = (np.random.uniform(low=-10.0, high=10.0, size=self._m) * caster).reshape(
            rowShape
        )

    def _AsRow(self, val: JaxOrNumpyArray) -> JaxOrNumpyArray:
        """
        Validate that val holds m entries and return it as a (1, m) row.

        Raises
        ------
        ValueError
            If val does not contain exactly m entries.
        """
        if val.size != self._m:
            raise ValueError(
                f"Input array of size {val.size} was received, but size {self._m} was expected."
            )
        if val.shape == (1, self._m):
            return val
        return val.reshape((1, self._m))

    @property
    def w(self) -> JaxOrNumpyArray:
        """
        Weights of the ELM

        Returns
        -------
        NDArray
            Weights of the ELM, shape (1, m).
        """
        return self._w

    @w.setter
    def w(self, val: JaxOrNumpyArray) -> None:
        """
        Replace the weights; val must contain exactly m entries.

        Parameters
        ----------
        val : NDArray
            New weights.
        """
        self._w = self._AsRow(val)

    @property
    def b(self) -> JaxOrNumpyArray:
        """
        Biases of the ELM

        Returns
        -------
        NDArray
            Biases of the ELM, shape (1, m).
        """
        return self._b

    @b.setter
    def b(self, val: JaxOrNumpyArray) -> None:
        """
        Replace the biases; val must contain exactly m entries.

        Parameters
        ----------
        val : NDArray
            New biases.
        """
        self._b = self._AsRow(val)


class ELMReLU(ELM):
    """
    Extreme learning machine basis functions with a ReLU activation.
    """

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Internal method used to calculate the ELMReLU basis function values.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for, as a column vector.
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Basis function values: one row per point in z, one column per
            basis function.
        """
        # Pre-activation for every (point, neuron) pair. ELM stores w and b as
        # (1, m) rows, so this broadcasts against the (N, 1) column z.
        arg = self._w * z + self._b
        if d == 0:
            return np.maximum(0.0, arg)
        if d == 1:
            # d/dz max(0, w z + b) = w where the pre-activation is positive.
            return self._w * np.where(arg > 0.0, 1.0, 0.0)
        # Higher derivatives of a piecewise-linear function vanish. Keep the same
        # (N, m) layout as the lower orders so column deletion still works.
        return np.zeros_like(arg)


class ELMSigmoid(ELM):
    """
    Extreme learning machine basis functions with a logistic-sigmoid activation.
    """

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Evaluate sigmoid(w z + b), or its d-th derivative in z, at the points z.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for.
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Basis function values, converted to a NumPy array.
        """

        from tfc.utils import egrad

        def activation(x: JaxOrNumpyArray) -> jnp.ndarray:
            return 1.0 / (1.0 + jnp.exp(-self._w * x - self._b))

        # Wrap with egrad once per requested derivative order.
        deriv = activation
        order = 0
        while order < d:
            deriv = egrad(deriv)
            order += 1

        return np.asarray(deriv(z))


class ELMTanh(ELM):
    """
    Extreme learning machine basis functions with a tanh activation.
    """

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Evaluate tanh(w z + b), or its d-th derivative in z, at the points z.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for.
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Basis function values, converted to a NumPy array.
        """

        from tfc.utils import egrad

        def activation(x: JaxOrNumpyArray) -> jnp.ndarray:
            return jnp.tanh(self._w * x + self._b)

        # Wrap with egrad once per requested derivative order.
        deriv = activation
        order = 0
        while order < d:
            deriv = egrad(deriv)
            order += 1

        return np.asarray(deriv(z))


class ELMSin(ELM):
    """
    Extreme learning machine basis functions with a sine activation.
    """

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Evaluate sin(w z + b), or its d-th derivative in z, at the points z.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for.
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Basis function values, converted to a NumPy array.
        """

        from tfc.utils import egrad

        def activation(x: JaxOrNumpyArray) -> jnp.ndarray:
            return jnp.sin(self._w * x + self._b)

        # Wrap with egrad once per requested derivative order.
        deriv = activation
        order = 0
        while order < d:
            deriv = egrad(deriv)
            order += 1

        return np.asarray(deriv(z))


class ELMSwish(ELM):
    """
    Extreme learning machine basis functions with a swish activation.
    """

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Evaluate s / (1 + exp(-s)) with s = w z + b, or its d-th derivative in z.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for.
        d : uint
            Derivative order.

        Returns
        -------
        H : NDArray
            Basis function values, converted to a NumPy array.
        """

        from tfc.utils import egrad

        def activation(x: JaxOrNumpyArray) -> jnp.ndarray:
            return (self._w * x + self._b) / (1.0 + jnp.exp(-self._w * x - self._b))

        # Wrap with egrad once per requested derivative order.
        deriv = activation
        order = 0
        while order < d:
            deriv = egrad(deriv)
            order += 1

        return np.asarray(deriv(z))


class nBasisFunc(BasisFunc):
    """
    Python implementation of the n-dimensional basis function classes.
    See the Python implementation of `BasisFunc` for details.

    The n-dimensional basis consists of products of the univariate basis
    functions of each dimension, restricted to products whose total degree
    is at most m - 1.
    """

    def __init__(
        self,
        x0: JaxOrNumpyArray,
        xf: JaxOrNumpyArray,
        nC: JaxOrNumpyArray,
        m: uint,
        z0: Number = 0.0,
        zf: Number = 0.0,
    ) -> None:
        """
        Initialize the basis class.

        Parameters
        ----------
        x0 : NDArray
            Start of the problem domain, one entry per dimension.
        xf : NDArray
            End of the problem domain, one entry per dimension.
        nC : NDArray
            Basis functions to be removed, shape (dim, numC): row k lists the
            univariate basis indices constrained in dimension k.
        m : uint
            Number of basis functions.
        z0 : Number
            Start of the basis function domain.
        zf : Number
            End of the basis function domain.
        """

        # Accept lists as well as arrays for the array-valued arguments.
        nC = np.asarray(nC)

        self._m_m = m
        self._nC_nC = nC
        self._dim = nC.shape[0]
        self._numC_numC = nC.shape[1]

        self._z0_z0 = z0
        self._zf = zf
        # Column vectors (dim, 1) broadcast against x of shape (dim, N).
        self._x0_x0 = np.asarray(x0).reshape((self._dim, 1))
        xf = np.asarray(xf).reshape((self._dim, 1))
        self._c_c = (zf - z0) / (xf - self._x0_x0)

        vec = np.zeros((self._dim, 1))
        self._numBasisFunc = self._NumBasisFunc(self._dim - 1, vec, full=False)
        self._numBasisFuncFull = self._NumBasisFunc(self._dim - 1, vec, full=True)

    def _IncludeTerm(self, vec: JaxOrNumpyArray, full: bool) -> bool:
        """
        Decide whether the product term indexed by vec belongs to the basis.

        Parameters
        ----------
        vec : NDArray
            Shape (dim, 1); vec[k] is the univariate basis index used in dimension k.
        full : bool
            If true, ignore the constraint list self._nC.

        Returns
        -------
        bool
            True when the term's total degree is at most m - 1 and, unless
            `full`, at least one dimension uses an unconstrained index.
        """
        if np.sum(vec) > self._m_m - 1:
            return False
        # A term is removed only when every dimension's index appears in that
        # dimension's row of nC.
        return full or not np.all(np.any(vec == self._nC_nC, axis=1))

    def _NumBasisFunc(self, dim: int, vec: JaxOrNumpyArray, n: int = 0, full: bool = False) -> int:
        """
        Calculate the number of basis functions.

        Parameters
        ----------
        dim : int
            Index of the dimension currently being enumerated (recursion runs
            from dim - 1 down to 0).
        vec : NDArray
            Vector used to keep track of the order of the basis function.
        n : int, optional
            Count of the number of basis functions so far. (Default value = 0)
        full : bool, optional
            If true, then does not remove basis functions based on self._nC. (Default value = False)

        Returns
        -------
        int
            Number of basis functions.
        """
        for idx in range(self._m_m):
            vec[dim] = idx
            if dim > 0:
                n = self._NumBasisFunc(dim - 1, vec, n=n, full=full)
            elif self._IncludeTerm(vec, full):
                n += 1
        return n

    @property
    def c(self) -> JaxOrNumpyArray:
        """
        Return the constants that map the problem domain to the basis
        function domain.

        Returns
        -------
        JaxOrNumpyArray
            The constants that map the problem domain to the basis function
            domain, shape (dim, 1).
        """

        return self._c_c

    @property
    def numBasisFunc(self) -> int:
        """
        Return the number of basis functions once user-specified
        functions have been removed.

        Returns
        -------
        int
            The number of basis functions once the user-specified
            functions have been removed.
        """

        return self._numBasisFunc

    @property
    def numBasisFuncFull(self) -> int:
        """
        Return the number of basis functions before the user-specified
        functions have been removed.

        Returns
        -------
        int
            The number of basis functions before the user-specified
            functions have been removed.
        """

        return self._numBasisFuncFull

    def H(self, x: JaxOrNumpyArray, d: JaxOrNumpyArray, full: bool = False) -> JaxOrNumpyArray:
        """
        Returns the basis function matrix for the x with a derivative of order d.

        Parameters
        -----------
        x : NDArray
            Input array. Values to calculate the basis function for.
            Should be size dim x N.
        d : NDArray
            Order of the derivative in each dimension (one entry per dimension).
        full: bool
            Whether to return the full basis function set, or remove
            the columns associated with self._nC.

        Returns
        -------
        H : NDArray
            The basis function values, shape (N, number of basis functions).

        Raises
        ------
        ValueError
            If x does not have one row per dimension.
        """

        # Check dimensions (the message previously referenced `z` before it
        # existed, turning this ValueError into a NameError).
        if x.shape[0] != self._dim:
            raise ValueError(
                f"Incorrect dimension for x. Expected {self._dim} but got {x.shape[0]}."
            )
        N = x.shape[1]

        # Convert to basis function domain
        z = (x - self._x0_x0) * self._c_c + self._z0_z0

        # Univariate basis functions of each dimension, including the
        # chain-rule factor c**d, stacked along the last axis.
        T = np.zeros((N, self._m_m, self._dim), dtype=z.dtype)
        for k in range(self._dim):
            T[:, :, k] = self._Hint(z[k : k + 1, :].T, d[k]) * self._c_c[k] ** d[k]

        def MultT(vec: JaxOrNumpyArray) -> JaxOrNumpyArray:
            """
            Multiply together the univariate basis functions selected by vec
            (vec[k] picks the column of dimension k) to form one n-D basis column.
            """
            tout = np.ones((N, 1), dtype=z.dtype)
            for k in range(self._dim):
                tout *= T[:, vec[k, 0] : vec[k, 0] + 1, k]
            return tout

        def Recurse(
            dim: int, out: JaxOrNumpyArray, vec: JaxOrNumpyArray, n: int = 0, full: bool = False
        ) -> Tuple[JaxOrNumpyArray, int]:
            """
            Enumerate all index combinations (dimension 0 varies fastest) and
            write every included product term into the next free column of out.

            Returns the filled array and the number of columns written so far.
            """
            for idx in range(self._m_m):
                vec[dim] = idx
                if dim > 0:
                    out, n = Recurse(dim - 1, out, vec, n=n, full=full)
                elif self._IncludeTerm(vec, full):
                    out[:, n : n + 1] = MultT(vec)
                    n += 1
            return out, n

        # Calculate and store all included combinations of the individual basis functions
        vec = np.zeros((self._dim, 1), dtype=int)
        nCols = self._numBasisFuncFull if full else self._numBasisFunc
        out = np.zeros((N, nCols), dtype=z.dtype)
        out, _ = Recurse(self._dim - 1, out, vec, full=full)

        return out


class nCP(nBasisFunc, CP):
    """
    Chebyshev polynomial basis functions in n dimensions.

    Inherits the n-dimensional machinery from ``nBasisFunc`` and the
    univariate basis from ``CP``.
    """

    def __init__(
        self,
        x0: JaxOrNumpyArray,
        xf: JaxOrNumpyArray,
        nC: JaxOrNumpyArray,
        m: uint,
    ) -> None:
        """
        Set up the n-dimensional Chebyshev polynomial basis.

        Parameters
        ----------
        x0 : NDArray
            Lower bound of the problem domain.
        xf : NDArray
            Upper bound of the problem domain.
        nC : NDArray
            Basis functions to be removed
        m : uint
            Number of basis functions.
        """

        # Chebyshev polynomials are defined on [-1, 1]; these bounds are handed
        # to nBasisFunc.__init__ as its last two (basis-domain) arguments.
        basis_lo, basis_hi = -1.0, 1.0
        nBasisFunc.__init__(self, x0, xf, nC, m, basis_lo, basis_hi)


class nLeP(nBasisFunc, LeP):
    """
    Legendre polynomial basis functions in n dimensions.

    Inherits the n-dimensional machinery from ``nBasisFunc`` and the
    univariate basis from ``LeP``.
    """

    def __init__(
        self,
        x0: JaxOrNumpyArray,
        xf: JaxOrNumpyArray,
        nC: JaxOrNumpyArray,
        m: uint,
    ) -> None:
        """
        Set up the n-dimensional Legendre polynomial basis.

        Parameters
        ----------
        x0 : NDArray
            Lower bound of the problem domain.
        xf : NDArray
            Upper bound of the problem domain.
        nC : NDArray
            Basis functions to be removed
        m : uint
            Number of basis functions.
        """

        # Legendre polynomials are defined on [-1, 1]; these bounds are handed
        # to nBasisFunc.__init__ as its last two (basis-domain) arguments.
        basis_lo, basis_hi = -1.0, 1.0
        nBasisFunc.__init__(self, x0, xf, nC, m, basis_lo, basis_hi)


class nFS(nBasisFunc, FS):
    """
    Fourier series basis functions in n dimensions.

    Inherits the n-dimensional machinery from ``nBasisFunc`` and the
    univariate basis from ``FS``.
    """

    def __init__(
        self,
        x0: JaxOrNumpyArray,
        xf: JaxOrNumpyArray,
        nC: JaxOrNumpyArray,
        m: uint,
    ) -> None:
        """
        Set up the n-dimensional Fourier series basis.

        Parameters
        ----------
        x0 : NDArray
            Lower bound of the problem domain.
        xf : NDArray
            Upper bound of the problem domain.
        nC : NDArray
            Basis functions to be removed
        m : uint
            Number of basis functions.
        """

        # The Fourier basis uses one period, [-pi, pi]; these bounds are handed
        # to nBasisFunc.__init__ as its last two (basis-domain) arguments.
        basis_lo, basis_hi = -np.pi, np.pi
        nBasisFunc.__init__(self, x0, xf, nC, m, basis_lo, basis_hi)


class nELM(nBasisFunc):
    """
    n-dimensional extreme learning machine abstract basis class.

    Weights ``w`` (shape dim x m) and biases ``b`` (shape 1 x m) are drawn
    uniformly from [-1, 1] at construction. Concrete subclasses implement
    ``_nHint`` to evaluate the basis functions (and their derivatives) in the
    basis function domain.
    """

    def __init__(
        self,
        x0: JaxOrNumpyArray,
        xf: JaxOrNumpyArray,
        nC: JaxOrNumpyArray,
        m: uint,
        z0: Number = 0.0,
        zf: Number = 1.0,
    ) -> None:
        """
        Initialize the n-dimensional ELM basis class.

        Parameters
        ----------
        x0 : NDArray
            Start of the problem domain, one entry per dimension.
        xf : NDArray
            End of the problem domain, one entry per dimension.
        nC : NDArray
            Basis functions to be removed. If every entry is -1, no basis
            functions are removed.
        m : uint
            Number of basis functions.
        z0 : Number
            Start of the basis function domain.
        zf : Number
            End of the basis function domain.
        """

        self._m_m_m = m
        self._nC_nC_nC = nC
        self._dim_dim = x0.size

        # An nC filled with -1 is the sentinel for "no constrained basis functions".
        if np.any(self._nC_nC_nC != -1):
            self._numC_numC_numC = nC.size
        else:
            self._numC_numC_numC = 0

        self._z0_z0_z0 = z0
        self._zf_zf = zf
        self._x0_x0_x0 = x0
        # Store the domain bounds as column vectors so they broadcast against
        # x, which H expects to be laid out as dim x N.
        if self._x0_x0_x0.shape != (self._dim_dim, 1):
            self._x0_x0_x0 = self._x0_x0_x0.reshape((self._dim_dim, 1))
        if xf.shape != (self._dim_dim, 1):
            xf = xf.reshape((self._dim_dim, 1))
        # Per-dimension scale of the linear map from the problem domain to the
        # basis function domain; also the chain-rule factor for derivatives.
        self._c_c_c = (zf - z0) / (xf - self._x0_x0_x0)

        self._numBasisFunc_numBasisFunc = self._m_m_m - self._numC_numC_numC
        self._numBasisFuncFull_numBasisFuncFull = self._m_m_m

        # NOTE(review): multiplying by ``one`` follows NumPy type promotion
        # rather than casting, so e.g. a float32 x0 still yields float64
        # weights; confirm whether a cast to x0.dtype was intended.
        one = np.ones(1, dtype=x0.dtype)
        self._w = np.random.uniform(low=-1.0, high=1.0, size=self._dim_dim * self._m_m_m) * one
        self._w = self._w.reshape((self._dim_dim, self._m_m_m))
        self._b = np.random.uniform(low=-1.0, high=1.0, size=self._m_m_m) * one
        self._b = self._b.reshape((1, self._m_m_m))

    @property
    def w(self) -> JaxOrNumpyArray:
        """
        Weights of the nELM

        Returns
        -------
        NDArray
            Weights of the ELM, shape dim x m.
        """
        return self._w

    @property
    def b(self) -> JaxOrNumpyArray:
        """
        Biases of the nELM

        Returns
        -------
        NDArray
            Biases of the ELM, shape 1 x m.
        """
        return self._b

    @w.setter
    def w(self, val: JaxOrNumpyArray) -> None:
        """
        Weights of the nELM.

        Parameters
        ----------
        val : NDArray
            New weights. Must contain dim * m elements; it is reshaped to
            dim x m if needed.

        Raises
        ------
        ValueError
            If ``val`` does not contain dim * m elements.
        """
        if val.size == self._m_m_m * self._dim_dim:
            self._w = val
            if self._w.shape != (self._dim_dim, self._m_m_m):
                self._w = self._w.reshape((self._dim_dim, self._m_m_m))
        else:
            # Report using the attributes this class actually sets; the previous
            # message referenced self._m / self._dim, which do not exist here.
            raise ValueError(
                f"Input array of size {val.size} was received, but size {self._m_m_m * self._dim_dim} was expected."
            )

    @b.setter
    def b(self, val: JaxOrNumpyArray) -> None:
        """
        Biases of the nELM.

        Parameters
        ----------
        val : NDArray
            New biases. Must contain m elements; it is reshaped to 1 x m if
            needed.

        Raises
        ------
        ValueError
            If ``val`` does not contain m elements.
        """
        if val.size == self._m_m_m:
            self._b = val
            if self._b.shape != (1, self._m_m_m):
                self._b = self._b.reshape((1, self._m_m_m))
        else:
            raise ValueError(
                f"Input array of size {val.size} was received, but size {self._m_m_m} was expected."
            )

    def H(self, x: JaxOrNumpyArray, d: JaxOrNumpyArray, full: bool = False) -> JaxOrNumpyArray:
        """
        Returns the basis function matrix for the x with a derivative of order d.

        Parameters
        ----------
        x : NDArray
            Input array. Values to calculate the basis function for.
            Should be size dim x N.
        d : NDArray
            Order of the derivative
        full : bool
            Whether to return the full basis function set, or remove
            the columns associated with self._nC.

        Returns
        -------
        H : NDArray
            The basis function values.

        Raises
        ------
        ValueError
            If the first dimension of ``x`` does not match the problem dimension.
        """

        # Check dimensions. Report x.shape[0]: z is not defined yet at this point.
        if x.shape[0] != self._dim_dim:
            raise ValueError(
                f"Incorrect dimension for x. Expected {self._dim_dim} but got {x.shape[0]}."
            )

        # Convert to basis function domain; transpose so each row is one point.
        z = ((x - self._x0_x0_x0) * self._c_c_c + self._z0_z0_z0).T

        F = self._nHint(z, d)
        if not full and self._numC_numC_numC > 0:
            F = np.delete(F, self._nC_nC_nC, axis=1)
        return F

    @abstractmethod
    def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray:
        """
        Internal method used to calculate the basis function value.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for.
        d : NDArray
            Derivative order.

        Returns
        -------
        H : NDArray
            Basis function values.
        """
        pass

    def _Hint(self, z: JaxOrNumpyArray, d: uint) -> JaxOrNumpyArray:
        """
        Dummy function, this should never be called!
        """
        raise ValueError("Error: This function should never be called.")


class nELMReLU(nELM):
    """
    n-dimensional ELM ReLU basis functions.
    """

    def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray:
        """
        Internal method used to calculate the basis function value.

        The ReLU is piecewise linear. A total derivative order above one
        therefore gives all zeros. A first derivative with respect to
        variable i is ``c[i] * w[i, :] * step(z @ w + b)``.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for.
        d : NDArray
            Derivative order with respect to each variable. Lists and tuples
            are accepted as well.

        Returns
        -------
        H : NDArray
            Basis function values.
        """
        # Convert so that ``d == 1`` below is element-wise even for list input
        # (comparing a plain list to 1 yields a single False).
        d = np.asarray(d)
        dorder = np.sum(d)

        if dorder > 1:
            # Derivative order is high enough that everything is zeros.
            return np.zeros((z.shape[0], self._m_m_m))

        activation = np.dot(z, self._w) + self._b
        if dorder == 1:
            # We have a first derivative on a single variable.
            ind = np.where(d == 1)[0][0]
            return (
                self._c_c_c[ind]
                * self._w[ind : ind + 1, :]
                * np.where(activation > 0.0, 1.0, 0.0)
            )
        return np.maximum(0.0, activation)


class nELMSin(nELM):
    """
    n-dimensional ELM sin basis functions.
    """

    def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray:
        """
        Internal method used to calculate the basis function value.

        The activation ``sin(z @ w + b)`` is differentiated with ``egrad``,
        ``d[i]`` times with respect to variable ``i``. The result is scaled by
        ``prod(c[i] ** d[i])``, the chain-rule factor of the map from the
        problem domain to the basis function domain.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for, one column per
            variable.
        d : NDArray
            Derivative order with respect to each variable. Each entry must be
            a non-negative integer.

        Returns
        -------
        H : NDArray
            Basis function values.

        Raises
        ------
        ValueError
            If an entry of ``d`` is negative or not an integer.
        """

        from tfc.utils import egrad

        def f(*x):
            return jnp.sin(jnp.dot(jnp.hstack(x), self._w) + self._b)

        # Pass one argument per variable so egrad can differentiate w.r.t. each.
        cols = jnp.split(z, z.shape[1], axis=1)

        func = f
        scale = 1
        for dim, deriv in enumerate(d):
            order = int(deriv)
            # A negative or fractional order would otherwise never terminate.
            if order != deriv or order < 0:
                raise ValueError(
                    f"Derivative orders must be non-negative integers, but got {deriv} for variable {dim}."
                )
            scale *= self._c_c_c[dim] ** deriv
            for _ in range(order):
                func = egrad(func, dim)

        return np.asarray(func(*cols) * scale)


class nELMTanh(nELM):
    """
    n-dimensional ELM tanh basis functions.
    """

    def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray:
        """
        Internal method used to calculate the basis function value.

        The activation ``tanh(z @ w + b)`` is differentiated with ``egrad``,
        ``d[i]`` times with respect to variable ``i``. The result is scaled by
        ``prod(c[i] ** d[i])``, the chain-rule factor of the map from the
        problem domain to the basis function domain.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for, one column per
            variable.
        d : NDArray
            Derivative order with respect to each variable. Each entry must be
            a non-negative integer.

        Returns
        -------
        H : NDArray
            Basis function values.

        Raises
        ------
        ValueError
            If an entry of ``d`` is negative or not an integer.
        """

        from tfc.utils import egrad

        def f(*x):
            return jnp.tanh(jnp.dot(jnp.hstack(x), self._w) + self._b)

        # Pass one argument per variable so egrad can differentiate w.r.t. each.
        cols = jnp.split(z, z.shape[1], axis=1)

        func = f
        scale = 1
        for dim, deriv in enumerate(d):
            order = int(deriv)
            # A negative or fractional order would otherwise never terminate.
            if order != deriv or order < 0:
                raise ValueError(
                    f"Derivative orders must be non-negative integers, but got {deriv} for variable {dim}."
                )
            scale *= self._c_c_c[dim] ** deriv
            for _ in range(order):
                func = egrad(func, dim)

        return np.asarray(func(*cols) * scale)


class nELMSigmoid(nELM):
    """
    n-dimensional ELM sigmoid basis functions.
    """

    def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray:
        """
        Internal method used to calculate the basis function value.

        The activation ``1 / (1 + exp(-(z @ w + b)))`` is differentiated with
        ``egrad``, ``d[i]`` times with respect to variable ``i``. The result is
        scaled by ``prod(c[i] ** d[i])``, the chain-rule factor of the map from
        the problem domain to the basis function domain.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for, one column per
            variable.
        d : NDArray
            Derivative order with respect to each variable. Each entry must be
            a non-negative integer.

        Returns
        -------
        H : NDArray
            Basis function values.

        Raises
        ------
        ValueError
            If an entry of ``d`` is negative or not an integer.
        """

        from tfc.utils import egrad

        def f(*x):
            return 1.0 / (1.0 + jnp.exp(-jnp.dot(jnp.hstack(x), self._w) - self._b))

        # Pass one argument per variable so egrad can differentiate w.r.t. each.
        cols = jnp.split(z, z.shape[1], axis=1)

        func = f
        scale = 1
        for dim, deriv in enumerate(d):
            order = int(deriv)
            # A negative or fractional order would otherwise never terminate.
            if order != deriv or order < 0:
                raise ValueError(
                    f"Derivative orders must be non-negative integers, but got {deriv} for variable {dim}."
                )
            scale *= self._c_c_c[dim] ** deriv
            for _ in range(order):
                func = egrad(func, dim)

        return np.asarray(func(*cols) * scale)


class nELMSwish(nELM):
    """
    n-dimensional ELM swish basis functions.
    """

    def _nHint(self, z: JaxOrNumpyArray, d: JaxOrNumpyArray) -> JaxOrNumpyArray:
        """
        Internal method used to calculate the basis function value.

        The activation ``a / (1 + exp(-a))`` with ``a = z @ w + b`` is
        differentiated with ``egrad``, ``d[i]`` times with respect to variable
        ``i``. The result is scaled by ``prod(c[i] ** d[i])``, the chain-rule
        factor of the map from the problem domain to the basis function domain.

        Parameters
        ----------
        z : NDArray
            Values to calculate the basis functions for, one column per
            variable.
        d : NDArray
            Derivative order with respect to each variable. Each entry must be
            a non-negative integer.

        Returns
        -------
        H : NDArray
            Basis function values.

        Raises
        ------
        ValueError
            If an entry of ``d`` is negative or not an integer.
        """

        from tfc.utils import egrad

        def f(*x):
            pre = jnp.dot(jnp.hstack(x), self._w) + self._b
            return pre / (1.0 + jnp.exp(-pre))

        # Pass one argument per variable so egrad can differentiate w.r.t. each.
        cols = jnp.split(z, z.shape[1], axis=1)

        func = f
        scale = 1
        for dim, deriv in enumerate(d):
            order = int(deriv)
            # A negative or fractional order would otherwise never terminate.
            if order != deriv or order < 0:
                raise ValueError(
                    f"Derivative orders must be non-negative integers, but got {deriv} for variable {dim}."
                )
            scale *= self._c_c_c[dim] ** deriv
            for _ in range(order):
                func = egrad(func, dim)

        return np.asarray(func(*cols) * scale)