Source code for tfc.utfc

from jax._src.config import config

config.update("jax_enable_x64", True)

import numpy as onp
import jax.numpy as np
import numpy.typing as npt
from typing import Optional, cast
from .utils.tfc_types import Literal, uint, IntArrayLike, JaxOrNumpyArray
from jax import core
from jax.extend.core import Primitive
from jax.interpreters import ad, batching, mlir
from jax.ffi import register_ffi_target
import jaxlib.mlir.ir as ir

# This is not part of the public API. However, it is what JAX uses internally in the ffi
# interface. We need this here, since we want to do very low-level things, like injecting
# new operands that are not traced into the C++ code.
# To switch to the new FFI interface, we would need to re-work all the C++ code to take
# in arguments as a JSON string. This would make the C++ way more confusing than it needs to be.
from jax._src.interpreters import mlir as mlir_int

from .utils.TFCUtils import TFCPrint


class utfc:
    """
    This is the univariate TFC class. It acts as a container that creates and stores:

    * The linear map between the free function domain (z) and the problem domain (x).
    * The basis functions or ELMs that make up the free function.
    * The necessary JAX code that enables automatic differentiation of the free function.
    * Other useful TFC related functions such as collocation point creation.

    In addition, this class ties these methods together to form a utility that enables a
    higher level of code abstraction such that the end-user scripts are simple, clear, and
    elegant implementations of TFC.

    Parameters
    ----------
    N : int
        Number of points to use when discretizing the domain.

    nC : IntArrayLike
        Number of functions to remove from the free function linear expansion. This variable
        is used to account for basis functions that are linearly dependent on support
        functions used in the construction of the constrained expressions. It can be
        expressed in 1 of 2 ways.

        1. As an integer. When expressed as an integer, the first nC basis functions are
           removed from the free function.
        2. As a list, tuple, or array. When expressed this way, the basis functions
           corresponding to the numbers given are removed from the free function.

    deg : int
        Degree of the basis function expansion. This number is one less than the number of
        basis functions used before removing those specified by nC.

    basis : Literal["CP","LeP","FS","ELMTanh","ELMSigmoid","ELMSin","ELMSwish","ELMReLU"], optional
        Specifies the basis functions to be used. (Default value = "CP")

    x0 : float, optional
        Specifies the beginning of the DE domain. (Default value = 0)

    xf : float, optional
        Specifies the end of the DE domain. If omitted it is set to 0.0; since x0 also
        defaults to 0.0, xf should normally be given explicitly.

    backend : Literal["C++", "Python"], optional
        Backend used to compute the basis functions. (Default value = "C++")
    """

    def __init__(
        self,
        N: uint,
        nC: IntArrayLike,
        deg: uint,
        basis: Literal[
            "CP", "LeP", "FS", "ELMTanh", "ELMSigmoid", "ELMSin", "ELMSwish", "ELMReLU"
        ] = "CP",
        x0: Optional[float] = None,
        xf: Optional[float] = None,
        backend: Literal["C++", "Python"] = "C++",
    ):
        """
        Constructor for the utfc class.

        Parameters
        ----------
        N : int
            Number of points to use when discretizing the domain.

        nC : int or list or tuple or array-like
            Number of functions to remove from the free function linear expansion. This
            variable is used to account for basis functions that are linearly dependent on
            support functions used in the construction of the constrained expressions. It
            can be expressed in 1 of 2 ways.

            1. As an integer. When expressed as an integer, the first nC basis functions
               are removed from the free function.
            2. As a list, tuple, or array (NumPy or JAX). When expressed this way, the
               basis functions corresponding to the numbers given are removed from the
               free function.

        deg : int
            Degree of the basis function expansion. This number is one less than the number
            of basis functions used before removing those specified by nC.

        basis : {"CP","LeP","FS","ELMTanh","ELMSigmoid","ELMSin","ELMSwish","ELMReLU"}, optional
            Specifies the basis functions to be used. (Default value = "CP")

        x0 : float, optional
            Specifies the beginning of the DE domain. (Default value = 0)

        xf : float, optional
            Specifies the end of the DE domain. If omitted it is set to 0.0; since x0 also
            defaults to 0.0, xf should normally be given explicitly.

        backend : Literal["C++", "Python"]
            This optional keyword sets the backend used to compute the basis functions. The
            C++ can be used with JIT, but can only be used for doubles. The Python backend
            can be used for other field types, e.g., complex numbers, but does not have JIT
            translations. Instead, pejit must be used to set the basis function outputs as
            compile time constants in order to JIT.
        """

        # Store givens
        self.N = N
        self.deg = deg
        self._backend = backend

        # Normalize nC to an int32 index array. NumPy arrays, tuples and NumPy integer
        # scalars are handled explicitly: they are not instances of jax.numpy.ndarray /
        # list / int, and previously left self.nC unset.
        if isinstance(nC, (int, onp.integer)):
            self.nC: npt.NDArray = onp.arange(nC, dtype=onp.int32)
        elif isinstance(nC, np.ndarray):
            self.nC = cast(npt.NDArray, nC.astype(onp.int32))
        elif isinstance(nC, onp.ndarray):
            self.nC = nC.astype(onp.int32)
        elif isinstance(nC, (list, tuple)):
            self.nC = np.array(nC, dtype=np.int32)
        else:
            msg = (
                "nC must be an integer, list, tuple, or array, "
                f"but an object of type {type(nC).__name__} was given."
            )
            TFCPrint.Error(msg)
            # Nothing below can work without self.nC, so make sure execution stops here
            # even if TFCPrint.Error only reports the problem.
            raise TypeError(msg)

        if self.nC.shape[0] > self.deg:
            TFCPrint.Error("Number of basis functions is less than number of constraints!")
        if np.any(self.nC < 0):
            TFCPrint.Error(
                "To set nC to -1 (no constraints) either use nC = -1 or nC = 0 (i.e., use an integer not a list or array). Do not put only -1 in a list or array, this will cause issues in the C++ layer."
            )

        self.basis = basis

        # Domain bounds: default to 0.0 and promote Python ints to float.
        self.x0 = 0.0 if x0 is None else x0
        if isinstance(self.x0, int):
            self.x0 = float(self.x0)
            TFCPrint.Warning("x0 is an integer. Converting to float to avoid errors down the line.")

        self.xf = 0.0 if xf is None else xf
        if isinstance(self.xf, int):
            self.xf = float(self.xf)
            TFCPrint.Warning("xf is an integer. Converting to float to avoid errors down the line.")

        # Setup the basis function
        if backend == "C++":
            from .utils import BF
        elif backend == "Python":
            from .utils import BF_Py as BF
        else:
            TFCPrint.Error(
                f'The backend {backend} was specified, but can only be one of "C++" or "Python".'
            )

        # Basis-function domain (z0, zf) for each supported basis. The basis class in the
        # backend module carries the same name as the basis string.
        zDomains = {
            "CP": (-1.0, 1.0),
            "LeP": (-1.0, 1.0),
            "FS": (-np.pi, np.pi),
            "ELMReLU": (0.0, 1.0),
            "ELMSigmoid": (0.0, 1.0),
            "ELMTanh": (0.0, 1.0),
            "ELMSin": (0.0, 1.0),
            "ELMSwish": (0.0, 1.0),
        }
        if self.basis in zDomains:
            z0, zf = zDomains[self.basis]
            basisType = getattr(BF, self.basis)
            self.basisClass = basisType(self.x0, self.xf, self.nC, self.deg + 1)
        else:
            TFCPrint.Error("Invalid basis selection. Please select a valid basis")

        self.c = self.basisClass.c

        # Calculate z points and corresponding x
        if self.basis in ["CP", "LeP"]:
            # Cosine-spaced points on [-1, 1].
            # NOTE(review): N == 1 makes n == 0 and the division below 0/0.
            n = self.N - 1
            # Multiplying x0 by 0 below so the index array has the same type as x0.
            idx = np.linspace(0 * self.x0, n, n + 1)
            self.z = np.cos(np.pi * (n - idx) / float(n))
        else:
            self.z = np.linspace(z0, zf, self.N)
        self.x = (self.z - z0) / self.c + self.x0

        self._SetupJax()

    def H(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the basis function matrix for the points specified by x.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        H : NDArray
            Basis function matrix.
        """
        return self._Hjax(x, d=0, full=full)

    def dH(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the derivative of H. See documentation of H for more details.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        dH : NDArray
            Derivative of the basis function matrix.
        """
        return self._Hjax(x, d=1, full=full)

    def d2H(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the second derivative of H. See documentation of H for more
        details.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        d2H : NDArray
            Second derivative of the basis function matrix.
        """
        return self._Hjax(x, d=2, full=full)

    def d3H(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the third derivative of H. See documentation of H for more
        details.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        d3H : NDArray
            Third derivative of the basis function matrix.
        """
        return self._Hjax(x, d=3, full=full)

    def d4H(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the fourth derivative of H. See documentation of H for more
        details.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        d4H : NDArray
            Fourth derivative of the basis function matrix.
        """
        return self._Hjax(x, d=4, full=full)

    def d8H(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the eighth derivative of H. See documentation of H for more
        details.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        d8H : NDArray
            Eighth derivative of the basis function matrix.
        """
        return self._Hjax(x, d=8, full=full)

    def _SetupJax(self):
        """
        This function is used internally by TFC to setup JAX primitives and create desired
        behavior when taking derivatives of TFC constrained expressions.

        It creates the primitive "H", attaches its implementation, abstract evaluation,
        batching rule and JVP rule, registers a CPU lowering when the C++ backend is in
        use, and stores the bound function on the instance as ``self._Hjax``.
        """

        # Register XLA function
        if self._backend == "C++":
            obj = self.basisClass.xlaCapsule
            xlaName = "BasisFunc" + str(self.basisClass.identifier)
            register_ffi_target(xlaName, obj, platform="cpu", api_version=0)

        def numFuncs(full: bool) -> int:
            # Number of output columns: every basis function when full is set, otherwise
            # the basis functions left after removing those specified by nC.
            if full:
                return self.basisClass.m
            return self.basisClass.m - self.basisClass.numC

        # Create primitives
        H_p = Primitive("H")

        def Hjax(x: JaxOrNumpyArray, d: uint = 0, full: bool = False) -> npt.NDArray:
            return cast(npt.NDArray, H_p.bind(x, d=d, full=full))

        # Implicit translation
        def H_impl(x: npt.NDArray, d: uint = 0, full=False) -> npt.NDArray:
            return self.basisClass.H(x, d, full)

        H_p.def_impl(H_impl)

        # Abstract evaluation
        def H_abstract_eval(x, d: uint = 0, full: bool = False) -> core.ShapedArray:
            dim1 = numFuncs(full)
            if len(x.shape) == 0:
                dims = (dim1,)
            else:
                dims = (x.shape[0], dim1)
            return core.ShapedArray(dims, x.dtype)

        H_p.def_abstract_eval(H_abstract_eval)

        if self._backend == "C++":
            # XLA compilation
            def H_xla(ctx, x, d: uint = 0, full: bool = False):
                x_ir_type = ir.RankedTensorType(x.type)  # x.type is already an ir.Type
                x_element_type = x_ir_type.element_type
                x_dims = x_ir_type.shape  # This is a list of integers

                # NOTE(review): unlike H_abstract_eval, this indexes the first dimension
                # unconditionally, so it expects x to have at least one dimension.
                dim0 = x_dims[0]
                dim1 = numFuncs(full)

                # Define Result Types
                result_types = [ir.RankedTensorType.get([dim0, dim1], x_element_type)]

                # Call mlir.custom_call. The identifier, derivative order, full flag and
                # output dimensions are passed as extra constant operands next to x.
                custom_call_op = mlir_int.custom_call(
                    call_target_name=xlaName,
                    result_types=result_types,
                    operands=[
                        mlir.ir_constant(np.int32(self.basisClass.identifier)),
                        x,
                        mlir.ir_constant(np.int32(d)),
                        mlir.ir_constant(np.bool(full)),
                        mlir.ir_constant(np.int32(dim0)),
                        mlir.ir_constant(np.int32(dim1)),
                    ],
                    has_side_effect=False,
                    api_version=3,
                )
                return custom_call_op.results

            mlir.register_lowering(H_p, H_xla, platform="cpu")

        # Define batching translation
        def H_batch(vec, batch, d: uint = 0, full: bool = False):
            return Hjax(*vec, d=d, full=full), batch[0]

        batching.primitive_batchers[H_p] = H_batch

        # Define jacobian vector product
        def H_jvp(arg_vals, arg_tans, d: uint = 0, full: bool = False):
            x = arg_vals[0]
            dx = arg_tans[0]

            # A symbolic zero tangent is an *instance* of ad.Zero; the identity check
            # against the class is kept only for backward compatibility.
            symbolicZero = isinstance(dx, ad.Zero) or dx is ad.Zero

            if not symbolicZero and onp.any(dx != 0):
                # Chain rule: the tangent is the next derivative of H scaled by dx.
                if len(dx.shape) == 1:
                    out_tans = Hjax(x, d=d + 1, full=full) * onp.expand_dims(dx, 1)
                else:
                    out_tans = Hjax(x, d=d + 1, full=full) * dx
            else:
                # Zero tangent (symbolic or numerically all zero): always assign
                # out_tans so the return below never sees an unbound name.
                out_tans = np.zeros((x.shape[0], numFuncs(full)))

            return (Hjax(x, d=d, full=full), out_tans)

        ad.primitive_jvps[H_p] = H_jvp

        # Provide pointer for TFC class
        self._Hjax = Hjax
class HybridUtfc:
    """
    This class combines TFC classes together so that multiple basis functions can be used
    simultaneously in the solution.

    Note, that this class is not yet complete.

    Parameters
    ----------
    tfcClasses : list of utfc classes
        This list of utfc classes make up the basis functions used in the HybridUtfc class.
    """

    def __init__(self, tfcClasses):
        """
        Constructor for the HybridUtfc class. Checks that every provided utfc class uses
        the same number of points and stores the list.

        Parameters
        ----------
        tfcClasses : list of utfc classes
            This list of utfc classes make up the basis functions used in the HybridUtfc
            class.
        """
        if not all(k.N == tfcClasses[0].N for k in tfcClasses):
            TFCPrint.Error("Not all TFC classes provided have the same number of points.")
        self._tfcClasses = tfcClasses

    def _stackH(self, x: JaxOrNumpyArray, d: int, full: bool) -> npt.NDArray:
        """
        Evaluate the d-th derivative of the basis function matrix of every stored utfc
        class at x and stack the results horizontally, in the order the classes were given.
        """
        return cast(
            npt.NDArray,
            np.hstack([k._Hjax(x, d=d, full=full) for k in self._tfcClasses]),
        )

    def H(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the basis function matrix for the points specified by x.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        H : NDArray
            Basis function matrix.
        """
        return self._stackH(x, 0, full)

    def dH(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the derivative of the basis function matrix for the points
        specified by x.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        dH : NDArray
            Derivative of the basis function matrix.
        """
        return self._stackH(x, 1, full)

    def d2H(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the second derivative of the basis function matrix for the
        points specified by x.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        d2H : NDArray
            Second derivative of the basis function matrix.
        """
        return self._stackH(x, 2, full)

    def d3H(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the third derivative of the basis function matrix for the
        points specified by x.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        d3H : NDArray
            Third derivative of the basis function matrix.
        """
        return self._stackH(x, 3, full)

    def d4H(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the fourth derivative of the basis function matrix for the
        points specified by x.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        d4H : NDArray
            Fourth derivative of the basis function matrix.
        """
        return self._stackH(x, 4, full)

    def d8H(self, x: JaxOrNumpyArray, full: bool = False) -> npt.NDArray:
        """
        This function computes the eighth derivative of the basis function matrix for the
        points specified by x.

        Parameters
        ----------
        x : JaxOrNumpyArray
            Points to calculate the basis functions at.

        full : bool, optional
            If true then the values specified by nC to the utfc class are ignored and all
            basis functions are computed. (Default value = False)

        Returns
        -------
        d8H : NDArray
            Eighth derivative of the basis function matrix.
        """
        return self._stackH(x, 8, full)