
.. _program_listing_file_src_tfc_utils_TFCUtils.py:

Program Listing for File TFCUtils.py
====================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_tfc_utils_TFCUtils.py>` (``src/tfc/utils/TFCUtils.py``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: py

   import sys
   from colorama import init as initColorama
   from colorama import Fore as fg
   from colorama import Style as style
   
   from collections import OrderedDict
   from functools import partial
   
   from jax._src.config import config
   
   config.update("jax_enable_x64", True)
   import numpy as onp
   import numpy.typing as npt
   import jax.numpy as np
   from jax import jvp, jit, lax, jacfwd, typeof
   from jax.extend import linear_util as lu
   from jax.api_util import debug_info
   from jax.tree_util import register_pytree_node, tree_map
   from jax._src.api_util import flatten_fun
   from jax._src.tree_util import tree_flatten
   from jax.core import eval_jaxpr
   from jax.interpreters.partial_eval import trace_to_jaxpr_nounits, PartialVal
   from jax.experimental import io_callback
   from typing import Any, Callable, Optional, cast, Union, overload
   from .tfc_types import uint, Literal, TypedDict, Path
   from jaxtyping import PyTree
   
   # Types that can be added to (or subtracted from) a TFCDict: a JAX array
   # (np is jax.numpy in this file), a plain dictionary, or another TFCDict.
   TFCDictAddable = Union[np.ndarray, dict[Any, Any], "TFCDict"]

   # Types that can be added to (or subtracted from) a TFCDictRobust: a JAX array
   # (np is jax.numpy in this file), a plain dictionary, or another TFCDictRobust.
   TFCDictRobustAddable = Union[np.ndarray, dict[Any, Any], "TFCDictRobust"]
   
   
   class TFCPrint:
       """
       This class is used to print to the terminal in color.
       """

       def __init__(self):
           """
           This function is the constructor. It initializes the colorama class.
           """
           initColorama()

       @staticmethod
       def Error(stringIn: str):
           """
           This function prints errors. It prints the text in 'stringIn' in bright red and
           exits the program with a non-zero exit status.

           Parameters
           ----------
           stringIn : str
               error string

           Raises
           ------
           SystemExit
               Always raised (with exit status 1) after the message has been printed.
           """
           print(fg.RED + style.BRIGHT + stringIn)
           print(style.RESET_ALL, end="")
           # Exit with a failure status so that shells and calling processes can tell
           # the program stopped because of an error; a bare sys.exit() reports success.
           sys.exit(1)

       @staticmethod
       def Warning(stringIn: str):
           """
           This function prints warnings. It prints the text in 'stringIn' in bright yellow.

           Parameters
           ----------
           stringIn : str
               warning string
           """
           print(fg.YELLOW + style.BRIGHT + stringIn)
           print(style.RESET_ALL, end="")
   
   
   def egrad(g: Callable[..., Any], j: uint = 0):
       """
       This function mimics egrad from autograd.

       Parameters
       ----------
       g : Callable[..., Any]
           Function to take the derivative of.

       j : uint, optional
           Parameter with which to take the derivative with respect to. (Default value = 0)

       Returns
       -------
       wrapped : function
           Derivative function
       """

       def wrapped(*args: Any) -> Any:
           """
           Wrapper for derivative of g with respect to parameter number j.

           Parameters
           ----------
           *args : Any
               function arguments to g. Each argument must expose ``shape`` and ``dtype``.

           Returns
           -------
           x_bar: Any
               derivative of g with respect to parameter number j
           """
           # Seed the tangents: ones for argument j, zeros for every other argument.
           # Each tangent takes the dtype of its argument because jvp requires the
           # primal and tangent dtypes to match (e.g. float32 or complex inputs).
           tans = tuple(
               onp.ones(arg.shape, dtype=arg.dtype)
               if i == j
               else onp.zeros(arg.shape, dtype=arg.dtype)
               for i, arg in enumerate(args)
           )
           _, x_bar = jvp(g, args, tans)
           return x_bar

       return wrapped
   
   
   @partial(partial, tree_map)
   def onesRobust(val: PyTree):
       """
       Build a pytree of ones that mirrors ``val``; works for arrays and dictionaries.

       The decorator turns this name into ``partial(tree_map, <this function>)``, so the
       body below is applied to one leaf of ``val`` at a time.

       Parameters
       ----------
       val : PyTree
           Pytree whose leaves expose ``shape`` and ``dtype``.

       Returns
       -------
       ones_like_val : PyTree
           Pytree with the same structure as val with all elements equal to one.

       """

       # Inside tree_map the argument is a single leaf, which is treated as array-like.
       # The cast only exists to keep LSPs/type checkers quiet; it does nothing at runtime.
       leaf = cast(npt.NDArray, val)
       leaf_shape, leaf_dtype = leaf.shape, leaf.dtype

       return onp.ones(leaf_shape, dtype=leaf_dtype)
   
   
   @partial(partial, tree_map)
   def zerosRobust(val: PyTree):
       """
       Build a pytree of zeros that mirrors ``val``; works for arrays and dictionaries.

       The decorator turns this name into ``partial(tree_map, <this function>)``, so the
       body below is applied to one leaf of ``val`` at a time.

       Parameters
       ----------
       val : PyTree
           Pytree whose leaves expose ``shape`` and ``dtype``.

       Returns
       -------
       zeros_like_val : PyTree
           Pytree with the same structure as val with all elements equal to zero.
       """

       # Inside tree_map the argument is a single leaf, which is treated as array-like.
       # The cast only exists to keep LSPs/type checkers quiet; it does nothing at runtime.
       leaf = cast(npt.NDArray, val)
       leaf_shape, leaf_dtype = leaf.shape, leaf.dtype

       return onp.zeros(leaf_shape, dtype=leaf_dtype)
   
   
   def egradRobust(g: Callable[..., Any], j: uint = 0):
       """This function mimics egrad from autograd, but can also handle dictionaries.

       Parameters
       ----------
       g : function
           Function to take the derivative of. Any callable is accepted, including
           callables that do not define ``__qualname__`` (e.g. ``functools.partial``).

       j : integer, optional
           Parameter with which to take the derivative with respect to. (Default value = 0)

       Returns
       -------
       wrapped : function
           Derivative function
       """
       # Unwrap functions whose qualified name is "jit.<locals>.f_jitted" so that the
       # underlying function is differentiated (presumably the closure returned by an
       # older jax.jit - TODO confirm). getattr is used because not every callable
       # carries a __qualname__ attribute.
       if getattr(g, "__qualname__", None) == "jit.<locals>.f_jitted":
           g = g.__wrapped__

       def wrapped(*args: Any) -> Any:
           """
           Wrapper for derivative of g with respect to parameter number j.

           Parameters
           ----------
           *args : iterable
               function arguments to g

           Returns
           -------
           x_bar: array-like
               derivative of g with respect to parameter number j
           """
           # Seed the tangents: ones for argument j, zeros for every other argument.
           # onesRobust/zerosRobust map over pytrees, so dictionaries are supported.
           tans = tuple(
               onesRobust(arg) if i == j else zerosRobust(arg) for i, arg in enumerate(args)
           )
           _, x_bar = jvp(g, args, tans)
           return x_bar

       return wrapped
   
   
   def pe(*args: Any, constant_arg_nums: list[int] = []) -> Any:
       """
       Decorator that returns a function evaluated such that the arg numbers specified in constant_arg_nums
       and all functions that utilizes only those arguments are treated as compile time constants.
   
       Parameters
       ----------
       *args : Any
           Arguments for the function that pe is applied to.
       constant_arg_nums : list[int], optional
           The arguments whose values and functions that depend only on these values should be
           treated as cached constants.
   
       Returns
       -------
       f : Any
           The new function whose constant_arg_num arguments have been removed. The jaxpr of this
           function has the constant_arg_num values and all functions that depend on those values
           cached as constants.
   
       Usage
       -----
       @pe(*args, constant_arg_nums=[0])
       def f(x,xi):
           # Function stuff here
   
       # Returns an f(xi) with x treated as constant
       """
   
       # Reorder to put knowns first, then unknowns
       order = [k for k in range(len(args))]
       for k in constant_arg_nums:
           order.insert(0, order.pop(k))
       reorder = np.argsort(np.array(order))
       dark = tuple(args[k] for k in order)
   
       # Store the removed args for later
       num_args_remove = len(constant_arg_nums)
   
       def wrapper(f_orig):
           if len(constant_arg_nums) > 0:
               # Reordering args so the ones to remove are given first
               # This will allow us to return a function that has completely removed those args
               # Moreover, we do it here so this reordering will be optimized by the compiler
               def f(*args):
                   new_args = tuple(args[k] for k in reorder)
                   return f_orig(*new_args)
   
               # Create the partial args needed by trace_to_jaxpr_nounits
               def get_arg(a, unknown):
                   if unknown:
                       return tree_flatten(
                           (
                               tree_map(lambda x: PartialVal.unknown(typeof(x).at_least_vspace()), a),
                               {},
                           )
                       )[0]
                   else:
                       return PartialVal.known(a)
   
               part_args = []
               for k, a in enumerate(dark):
                   temp = get_arg(a, k >= num_args_remove)
                   if isinstance(temp, list):
                       part_args += temp
                   else:
                       part_args.append(temp)
               part_args = tuple(part_args)
   
               # Create jaxpr
               wrap = lu.wrap_init(f, debug_info=debug_info(f.__name__, f, [], {}))
               _, in_tree = tree_flatten((dark, {}))
               wrap_flat, out_tree = flatten_fun(wrap, in_tree)
               jaxpr, _, const = trace_to_jaxpr_nounits(wrap_flat, part_args)
   
               # Create new, partially evaluated function
               if out_tree().num_leaves == 1 and out_tree().num_nodes == 1:
                   # out_tree() is PyTreeDef(*), so just return the value. Since eval_jaxpr returns a list,
                   # this is just value [0]
                   f_removed = lambda *args: eval_jaxpr(jaxpr, const, *tree_flatten((*args, {}))[0])[0]
               else:
                   # Use out_tree() to reshape the args correctly.
                   f_removed = lambda *args: out_tree().unflatten(
                       eval_jaxpr(jaxpr, const, *tree_flatten((*args, {}))[0])
                   )
               return f_removed
           else:
               return f_orig
   
       return wrapper
   
   
   def pejit(*args: Any, constant_arg_nums: list[int] = [], **kwargs) -> Any:
       """
       JIT-compiling variant of :func:`pe <tfc.utils.TFCUtils.pe>`: the decorated function is first
       partially evaluated by :func:`pe <tfc.utils.TFCUtils.pe>` and the result is then passed to JIT.
       See :func:`pe <tfc.utils.TFCUtils.pe>` for more details.

       Parameters
       ----------
       *args : Any
           Arguments for the function that pe is applied to.
       constant_arg_nums : list[int], optional
           The arguments whose values and functions that depend only on these values should be
           treated as cached constants.
       **kwargs : Any
           Keyword arguments passed on to JIT.

       Returns
       -------
       f : Any
           The new function whose constant_arg_num arguments have been removed. The jaxpr of this
           function has the constant_arg_num values and all functions that depend on those values
           cached, and they are treated as compile time constants.
       """

       def wrap(f_orig):
           # Partially evaluate first, then hand the reduced function to jit.
           decorate = pe(*args, constant_arg_nums=constant_arg_nums)
           reduced = decorate(f_orig)
           return jit(reduced, **kwargs)

       return wrap
   
   
   class TFCDict(OrderedDict):
       """
       This is the TFC dictionary class. It extends an OrderedDict and
       adds a few methods that enable:

         - Adding dictionaries with the same keys together
         - Turning a dictionary into a 1-D array
         - Turning a 1-D array into a dictionary

       The key list and the per-key slices are cached when the dictionary is constructed
       and whenever :meth:`update` is called.
       """

       def __init__(self, *args: Any, **kwargs: Any):
           """
           Initialize TFCDict using the OrderedDict method.

           Parameters
           ----------
           *args : Any
               Arguments to pass to the OrderedDict
           **kwargs : Any
               Keyword arguments to pass to the OrderedDict
           """

           # Store dictionary and keep a record of the keys. Keys will stay in same
           # order, so that adding and subtracting is repeatable.
           super().__init__(*args, **kwargs)
           self._syncKeys()

       def _syncKeys(self) -> None:
           """
           Refresh the cached key list, key count and slices from the current contents.
           """
           self._keys = list(self.keys())
           self._nKeys = len(self._keys)
           self.getSlices()

       def getSlices(self) -> None:
           """
           Function that creates slices for each of the keys in the dictionary.

           If every value is a JAX or NumPy array, slice k covers the position of the k-th value
           inside the flat array returned by toArray. Otherwise every slice is set to None.
           """
           if all(isinstance(value, (np.ndarray, onp.ndarray)) for value in self.values()):
               slices = []
               stop = 0
               for key in self._keys:
                   # Each value occupies shape[0] consecutive entries of the flat array
                   start = stop
                   stop = start + self[key].shape[0]
                   slices.append(slice(start, stop, 1))
               self._slices = slices
           else:
               self._slices = [None] * self._nKeys

       def update(self, *args: Any, **kwargs: Any) -> None:
           """
           Overload the update method to update the _keys variable as well.

           Parameters
           ----------
           *args : Any
               Same as *args for the update method of ordered dict.
           **kwargs : Any
               Same as **kwargs for the update method of ordered dict.
           """
           super().update(*args, **kwargs)
           self._syncKeys()

       def toArray(self) -> np.ndarray:
           """
           Send dictionary to a flat JAX array.

           Returns
           -------
           np.ndarray
               This dictionary as a flat JAX array.
           """
           return cast(np.ndarray, np.hstack(list(self.values())))

       def toDict(self, arr: np.ndarray) -> "TFCDict":
           """
           Send a flat JAX array to a TFCDict with the same keys.

           Parameters
           ----------
           arr : ndarray
               Flat JAX array to convert to TFCDict. Must have the same number of elements as total number of elements in the dictionary.

           Returns
           -------
           TFCDict
               JAX array as a TFCDict
           """
           flat = arr.flatten()
           return TFCDict(zip(self._keys, [flat[sl] for sl in self._slices]))

       def block_until_ready(self) -> "TFCDict":
           """
           Mimics block_until_ready for jax arrays. Used to halt the program until the computation that created the
           dictionary is finished.

           Every value that provides a block_until_ready method is waited on; values without
           one (e.g. NumPy arrays) are skipped.

           Returns
           -------
           TFCDict
               This TFCDict.
           """
           for value in self.values():
               wait = getattr(value, "block_until_ready", None)
               if wait is not None:
                   wait()
           return self

       def _operandItems(self, o: Any) -> list:
           """
           Pair each key of this dictionary with the matching piece of the operand o.

           Parameters
           ----------
           o : Any
               Operand of an arithmetic operation.

           Returns
           -------
           list
               (key, value) pairs. For a dictionary operand the value is o[key]; for a JAX or
               NumPy array operand it is the slice of the flattened array that belongs to key.
               The list is empty for any other operand type, which leaves the result unchanged.
           """
           # A TFCDict is itself a dict, so one check covers plain dicts and TFCDicts.
           if isinstance(o, dict):
               return [(key, o[key]) for key in self._keys]
           if isinstance(o, (np.ndarray, onp.ndarray)):
               flat = o.flatten()
               return [(key, flat[sl]) for key, sl in zip(self._keys, self._slices)]
           return []

       def __iadd__(self, o: TFCDictAddable) -> "TFCDict":
           """
           Used to overload "+=" for TFCDict so that 2 TFCDict's can be added together.

           Parameters
           ----------
           o : TFCDictAddable
               Values to add to the current dictionary.

           Returns
           -------
           self : TFCDict
               This dictionary (not a copy) after adding in the values from o.
           """
           for key, value in self._operandItems(o):
               self[key] += value
           return self

       def __isub__(self, o: TFCDictAddable) -> "TFCDict":
           """
           Used to overload "-=" for TFCDict so that 2 TFCDict's can be subtracted.

           Parameters
           ----------
           o : TFCDictAddable
               Values to subtract from the current dictionary.

           Returns
           -------
           self : TFCDict
               This dictionary (not a copy) after subtracting the values from o.
           """
           for key, value in self._operandItems(o):
               self[key] -= value
           return self

       def __add__(self, o: TFCDictAddable) -> "TFCDict":
           """
           Used to overload "+" for TFCDict so that 2 TFCDict's can be added together.

           Parameters
           ----------
           o : TFCDictAddable
               Values to add to the current dictionary.

           Returns
           -------
           out : TFCDict
               A new TFCDict with values = self + o. Neither operand is modified.
           """
           out = TFCDict(self)
           for key, value in self._operandItems(o):
               # out shares its value objects with self, so build a new value instead of
               # using "+=", which would modify self's NumPy arrays in place.
               out[key] = out[key] + value
           return out

       def __sub__(self, o: TFCDictAddable) -> "TFCDict":
           """
           Used to overload "-" for TFCDict so that 2 TFCDict's can be subtracted.

           Parameters
           ----------
           o : TFCDictAddable
               Values to subtract from the current dictionary.

           Returns
           -------
           out : TFCDict
               A new TFCDict with values = self - o. Neither operand is modified.
           """
           out = TFCDict(self)
           for key, value in self._operandItems(o):
               # out shares its value objects with self, so build a new value instead of
               # using "-=", which would modify self's NumPy arrays in place.
               out[key] = out[key] - value
           return out
   
   
   # Register TFCDict as a JAX pytree node. The values are the children and the keys are
   # the auxiliary data stored in the tree definition. The keys are stored as a tuple
   # because JAX documents the auxiliary data as hashable (a list is not).
   register_pytree_node(
       TFCDict,
       lambda x: (list(x.values()), tuple(x.keys())),
       lambda keys, values: TFCDict(zip(keys, values)),
   )
   
   
   class TFCDictRobust(OrderedDict):
       """
       This class is like the :class:`TFCDict <tfc.utils.TFCUtils.TFCDict>` class, but it handles non-flat arrays.

       The key list and the per-key slices are cached when the dictionary is constructed
       and whenever :meth:`update` is called.
       """

       def __init__(self, *args: Any, **kwargs: Any):
           """
           Initialize TFCDictRobust using the OrderedDict method.

           Parameters
           ----------
           *args : Any
               Arguments to pass to the OrderedDict
           **kwargs : Any
               Keyword arguments to pass to the OrderedDict
           """

           # Store dictionary and keep a record of the keys. Keys will stay in same
           # order, so that adding and subtracting is repeatable.
           super().__init__(*args, **kwargs)
           self._syncKeys()

       def _syncKeys(self) -> None:
           """
           Refresh the cached key list, key count and slices from the current contents.
           """
           self._keys = list(self.keys())
           self._nKeys = len(self._keys)
           self.getSlices()

       def getSlices(self) -> None:
           """
           Function that creates slices for each of the keys in the dictionary.

           If every value is a JAX or NumPy array, slice k covers the position of the k-th
           (flattened) value inside the flat array returned by toArray. Otherwise every slice
           is set to None.
           """
           if all(isinstance(value, (np.ndarray, onp.ndarray)) for value in self.values()):
               slices = []
               stop = 0
               for key in self._keys:
                   # A value occupies as many entries of the flat array as it has elements
                   start = stop
                   stop = start + self[key].size
                   slices.append(slice(start, stop, 1))
               self._slices = slices
           else:
               self._slices = [None] * self._nKeys

       def update(self, *args: Any, **kwargs: Any) -> None:
           """
           Overload the update method to update the _keys variable as well.

           Parameters
           ----------
           *args : Any
               Same as *args for the update method on ordered dict.
           **kwargs : Any
               Same as **kwargs for the update method on ordered dict.
           """
           super().update(*args, **kwargs)
           self._syncKeys()

       def toArray(self) -> np.ndarray:
           """
           Send dictionary to a flat JAX array.

           Returns
           -------
           np.ndarray
               This dictionary as a flat JAX array.
           """
           return cast(np.ndarray, np.hstack([value.flatten() for value in self.values()]))

       def toDict(self, arr: np.ndarray) -> "TFCDictRobust":
           """
           Send a flat JAX array to a TFCDictRobust with the same keys.

           Parameters
           ----------
           arr : np.ndarray
               Flat JAX array to convert to TFCDictRobust. Must have the same number of elements as total number of elements in the dictionary.

           Returns
           -------
           TFCDictRobust
               JAX array as a TFCDictRobust. Each value has the shape of the value stored
               under the same key in this dictionary.
           """
           flat = arr.flatten()
           return TFCDictRobust(
               zip(
                   self._keys,
                   [
                       flat[sl].reshape(self[key].shape)
                       for key, sl in zip(self._keys, self._slices)
                   ],
               )
           )

       def block_until_ready(self) -> "TFCDictRobust":
           """
           Mimics block_until_ready for jax arrays. Used to halt the program until the computation that created the
           dictionary is finished.

           Every value that provides a block_until_ready method is waited on; values without
           one (e.g. NumPy arrays) are skipped.

           Returns
           -------
           TFCDictRobust
               This TFCDictRobust
           """
           for value in self.values():
               wait = getattr(value, "block_until_ready", None)
               if wait is not None:
                   wait()
           return self

       def _operandItems(self, o: Any) -> list:
           """
           Pair each key of this dictionary with the matching piece of the operand o.

           Parameters
           ----------
           o : Any
               Operand of an arithmetic operation.

           Returns
           -------
           list
               (key, value) pairs. For a dictionary operand the value is o[key]; for a JAX or
               NumPy array operand it is the slice of the flattened array that belongs to key,
               reshaped to the shape of this dictionary's value. The list is empty for any
               other operand type, which leaves the result unchanged.
           """
           # A TFCDictRobust is itself a dict, so one check covers plain dicts as well.
           if isinstance(o, dict):
               return [(key, o[key]) for key in self._keys]
           if isinstance(o, (np.ndarray, onp.ndarray)):
               flat = o.flatten()
               return [
                   (key, flat[sl].reshape(self[key].shape))
                   for key, sl in zip(self._keys, self._slices)
               ]
           return []

       def __iadd__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
           """
           Used to overload "+=" for TFCDictRobust so that 2 TFCDictRobust's can be added together.

           Parameters
           ----------
           o : TFCDictRobustAddable
               Values to add to the current dictionary.

           Returns
           -------
           self : TFCDictRobust
               This dictionary (not a copy) after adding in the values from o.
           """
           for key, value in self._operandItems(o):
               self[key] += value
           return self

       def __isub__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
           """
           Used to overload "-=" for TFCDictRobust so that 2 TFCDictRobust's can be subtracted.

           Parameters
           ----------
           o : TFCDictRobustAddable
               Values to subtract from the current dictionary.

           Returns
           -------
           self : TFCDictRobust
               This dictionary (not a copy) after subtracting the values from o.
           """
           for key, value in self._operandItems(o):
               self[key] -= value
           return self

       def __add__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
           """
           Used to overload "+" for TFCDictRobust so that 2 TFCDictRobust's can be added together.

           Parameters
           ----------
           o : TFCDictRobustAddable
               Values to add to the current dictionary.

           Returns
           -------
           out : TFCDictRobust
               A new TFCDictRobust with values = self + o. Neither operand is modified.
           """
           out = TFCDictRobust(self)
           for key, value in self._operandItems(o):
               # out shares its value objects with self, so build a new value instead of
               # using "+=", which would modify self's NumPy arrays in place.
               out[key] = out[key] + value
           return out

       def __sub__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
           """
           Used to overload "-" for TFCDictRobust so that 2 TFCDictRobust's can be subtracted.

           Parameters
           ----------
           o : TFCDictRobustAddable
               Values to subtract from the current dictionary.

           Returns
           -------
           out : TFCDictRobust
               A new TFCDictRobust with values = self - o. Neither operand is modified.
           """
           out = TFCDictRobust(self)
           for key, value in self._operandItems(o):
               # out shares its value objects with self, so build a new value instead of
               # using "-=", which would modify self's NumPy arrays in place.
               out[key] = out[key] - value
           return out
   
   
   # Register TFCDictRobust as a JAX pytree node. The values are the children and the keys
   # are the auxiliary data stored in the tree definition. The keys are stored as a tuple
   # because JAX documents the auxiliary data as hashable (a list is not).
   register_pytree_node(
       TFCDictRobust,
       lambda x: (list(x.values()), tuple(x.keys())),
       lambda keys, values: TFCDictRobust(zip(keys, values)),
   )
   
   
   @overload
   def LS(
       zXi: PyTree,
       res: Callable,
       *args: Any,
       constant_arg_nums: list[int] = [],
       J: Optional[Callable[..., np.ndarray]] = None,
       method: Literal["pinv", "lstsq"] = "pinv",
       timer: Literal[False] = False,
       timerType: str = "process_time",
       holomorphic: bool = False,
   ) -> PyTree: ...
   @overload
   def LS(
       zXi: PyTree,
       res: Callable,
       *args: Any,
       constant_arg_nums: list[int] = [],
       J: Optional[Callable[..., np.ndarray]] = None,
       method: Literal["pinv", "lstsq"] = "pinv",
       timer: Literal[True] = True,
       timerType: str = "process_time",
       holomorphic: bool = False,
   ) -> tuple[PyTree, float]: ...
   def LS(
       zXi: PyTree,
       res: Callable,
       *args: Any,
       constant_arg_nums: list[int] = [],
       J: Optional[Callable[..., np.ndarray]] = None,
       method: Literal["pinv", "lstsq"] = "pinv",
       timer: bool = False,
       timerType: str = "process_time",
       holomorphic: bool = False,
   ) -> PyTree | tuple[PyTree, float]:
       """
       JITed least squares.
       This function takes in an initial guess of zeros, zXi, and a residual function, res, and
       uses linear least squares to minimize the res function using the parameters
       xi.

       Parameters
       ----------
       zXi : PyTree
           Unknown parameters to be found using least-squares.

       res : Callable
           Residual function (also known as the loss function) with signature res(xi: PyTree, *args:Any, **kwargs:Any).
           Note, the first argument does not need to be named xi, this is just illustrative.

       *args : Any
           Any additional arguments taken by res other than the first PyTree argument.

       constant_arg_nums: list[int], optional
           These arguments will be removed from the residual function and treated as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details.
           The positions count zXi as argument 0, so position k refers to args[k - 1]. The list
           may be given in any order and is not modified.

       J : Optional[Callable[...,np.ndarray]]
            User specified Jacobian function. If None, then the Jacobian of res with respect to xi will be calculated via automatic differentiation. (Default value = None)

       method : Literal["pinv","lstsq"], optional
            Method for least-squares inversion. (Default value = "pinv")
            * pinv - Use np.linalg.pinv
            * lstsq - Use np.linalg.lstsq

       timer : bool, optional
            Boolean that chooses whether to time the code or not. (Default value = False). Note that setting to true adds a slight increase in runtime,
            as the least-squares function is run once before the timed run to avoid timing the JAX trace.

       timerType : str, optional
            Any timer from the time module. (Default value = "process_time")

       holomorphic : bool, optional
            Indicates whether residual function is promised to be holomorphic. (Default value = False)

       Returns
       -------
       xi : pytree or array-like
            Unknowns that minimize res as found via least-squares. Type will be the same as zXi specified in the input.

       time : float
            Computation time as calculated by timerType specified. This output is only returned if timer = True.
       """

       dictFlag = isinstance(zXi, (TFCDict, TFCDictRobust))

       if J is None:
           if dictFlag:
               if isinstance(zXi, TFCDictRobust):

                   def J(xi, *args):
                       jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                       # Collapse the trailing dimensions of every block to a single column
                       # dimension so that the blocks can be stacked side by side.
                       return np.hstack(
                           [
                               jacob[k].reshape(jacob[k].shape[0], onp.prod(onp.array(xi[k].shape)))
                               for k in xi.keys()
                           ]
                       )

               else:

                   def J(xi, *args):
                       jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                       return np.hstack([jacob[k] for k in xi.keys()])

           else:
               J = lambda xi, *args: jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)

       if method == "pinv":
           ls = lambda xi, *args: np.dot(np.linalg.pinv(J(xi, *args)), -res(xi, *args))
       elif method == "lstsq":
           ls = lambda xi, *args: np.linalg.lstsq(J(xi, *args), -res(xi, *args), rcond=None)[0]
       else:
           TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

       if constant_arg_nums:
           # Work on a sorted, de-duplicated copy so that the caller's list is left untouched
           # and pe receives the positions in ascending order.
           const_nums = sorted(set(constant_arg_nums))

           # Make arguments constant if desired
           ls = pe(zXi, *args, constant_arg_nums=const_nums)(ls)

           # Drop the now-constant arguments from args. Position k counts zXi as argument 0,
           # hence the k - 1. Removing the highest position first keeps the remaining
           # positions valid.
           args: list[Any] = list(args)
           for k in reversed(const_nums):
               args.pop(k - 1)

       ls = jit(ls)
       zXi = zerosRobust(zXi)

       if timer:
           import time

           timer_f: Callable[[], float] = getattr(time, timerType)
           # Run once before timing so that the JAX trace is not included in the timing
           ls(zXi, *args).block_until_ready()

           start = timer_f()
           xi = ls(zXi, *args).block_until_ready()
           stop = timer_f()
           zXi += xi

           return zXi, stop - start
       else:
           zXi += ls(zXi, *args)
           return zXi
   
   
   class LsClass:
       """
       JITed linear least-squares class.
       Like the :func:`LS <tfc.utils.TFCUtils.LS>` function, but it is in class form so that the run method can be called multiple times without re-JITing.
       See :func:`LS <tfc.utils.TFCUtils.LS>` for more details.
       """

       def __init__(
           self,
           zXi: PyTree,
           res: Callable,
           *args: Any,
           constant_arg_nums: list[int] = [],
           J: Optional[Callable[..., np.ndarray]] = None,
           method: Literal["pinv", "lstsq"] = "pinv",
           timer: bool = False,
           timerType: str = "process_time",
           holomorphic: bool = False,
       ) -> None:
           """
           Initialization function. Creates the JIT-ed least-squares function.

           Parameters
           ----------
           zXi : PyTree
               Unknown parameters to be found using least-squares.

           res : Callable
               Residual function (also known as the loss function) with signature res(xi: PyTree, *args:Any, **kwargs:Any).
               Note, the first argument does not need to be named xi, this is just illustrative.

           *args : Any
               Any additional arguments taken by res other than the first PyTree argument.

           J : Optional[Callable[...,np.ndarray]]
                User specified Jacobian function. If None, then the Jacobian of res with respect to xi will be calculated via automatic differentiation. (Default value = None)

           constant_arg_nums: list[int], optional
               These arguments will be removed from the residual function and treated as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details.
               The list is only read; it is never modified by this class.

           method : Literal["pinv","lstsq"], optional
                Method for least-squares inversion. (Default value = "pinv")
                * pinv - Use np.linalg.pinv
                * lstsq - Use np.linalg.lstsq

           timer : bool, optional
                Boolean that chooses whether to time the code or not. (Default value = False). Note that setting to true adds a slight increase in runtime,
                as the least-squares function is run once first to avoid timing the JAX trace.

           timerType : str, optional
                Any timer from the time module. (Default value = "process_time")

           holomorphic : bool, optional
                Indicates whether residual function is promised to be holomorphic. (Default value = False)
           """

           self.timerType = timerType
           self.timer = timer
           self.holomorphic = holomorphic

           dictFlag = isinstance(zXi, (TFCDict, TFCDictRobust))

           if J is None:
               if dictFlag:
                   if isinstance(zXi, TFCDictRobust):

                       def J(xi, *args):
                           # jacfwd returns one Jacobian block per key; each block is
                           # flattened to 2-D before being stacked column-wise.
                           jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)
                           return np.hstack(
                               [
                                   jacob[k].reshape(
                                       jacob[k].shape[0], onp.prod(onp.array(xi[k].shape))
                                   )
                                   for k in xi.keys()
                               ]
                           )

                   else:

                       def J(xi, *args):
                           # One Jacobian block per key, stacked column-wise.
                           jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)
                           return np.hstack([jacob[k] for k in xi.keys()])

               else:
                   J = lambda xi, *args: jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)

           # The solver returns the update that is added to zXi in run().
           if method == "pinv":
               ls = lambda xi, *args: np.dot(np.linalg.pinv(J(xi, *args)), -res(xi, *args))
           elif method == "lstsq":
               ls = lambda xi, *args: np.linalg.lstsq(J(xi, *args), -res(xi, *args), rcond=None)[0]
           else:
               TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

           if constant_arg_nums:
               # Make arguments constant if desired. The caller's list is left untouched;
               # the reduced argument list is not needed here because run() receives its
               # own arguments.
               ls = pe(zXi, *args, constant_arg_nums=constant_arg_nums)(ls)

           self._ls = jit(ls)

           # Tracks whether the JIT-ed function has been called at least once, so the
           # timed path can skip the warm-up call.
           self._compiled = False

       def run(self, zXi: PyTree, *args: Any) -> PyTree | tuple[PyTree, float]:
           """
           Runs the JIT-ed least-squares function and times it if desired.

           Parameters
           ----------
           zXi : PyTree
               Unknown parameters to be found using least-squares.

           *args : Any
               Any additional arguments taken by res other than xi.
               NOTE(review): presumably this excludes arguments frozen through constant_arg_nums - confirm against pe/pejit.

           Returns
           -------
           xi : PyTree
                Unknowns that minimize res as found via least-squares. Type will be the same as zXi specified in the input.

           time : float, optional
                Computation time as calculated by timerType specified. This output is only returned if timer = True.

           """

           if self.timer:
               import time

               timer = getattr(time, self.timerType)

               if not self._compiled:
                   # Warm-up call so that tracing/compilation is not included in the timing.
                   self._ls(zXi, *args).block_until_ready()
                   self._compiled = True

               start = timer()
               xi = self._ls(zXi, *args).block_until_ready()
               stop = timer()
               zXi += xi

               return zXi, stop - start

           else:
               # Use the JIT-ed function stored on the instance; the local name `ls`
               # only exists inside __init__.
               zXi += self._ls(zXi, *args)

               self._compiled = True

               return zXi
   
   
   def nlls_id_print(it: int, x, end: str = "\n"):
       """
       Print one progress line of the nonlinear least-squares loop.

       Parameters
       ----------
       it : int
           Iteration counter to report.

       x : Any
           Value reported as max(abs(res)).

       end : str, optional
           Passed through as the ``end`` keyword of print. (Default value = "\\n")
       """
       message = "Iteration: {0}\tmax(abs(res)): {1}".format(it, x)
       print(message, end=end)
   
   
   @overload
   def NLLS(
       xiInit: PyTree,
       res: Callable,
       *args: Any,
       constant_arg_nums: list[int] = [],
       J: Optional[Callable[..., np.ndarray]] = None,
       cond: Optional[Callable[[PyTree], bool]] = None,
       body: Optional[Callable[[PyTree], PyTree]] = None,
       tol: float = 1e-13,
       maxIter: uint = 50,
       method: Literal["pinv", "lstsq"] = "pinv",
       timer: Literal[False] = False,
       printOut: bool = False,
       printOutEnd: str = "\n",
       timerType: str = "process_time",
       holomorphic: bool = False,
   ) -> tuple[PyTree, int]: ...
   @overload
   def NLLS(
       xiInit: PyTree,
       res: Callable,
       *args: Any,
       constant_arg_nums: list[int] = [],
       J: Optional[Callable[..., np.ndarray]] = None,
       cond: Optional[Callable[[PyTree], bool]] = None,
       body: Optional[Callable[[PyTree], PyTree]] = None,
       tol: float = 1e-13,
       maxIter: uint = 50,
       method: Literal["pinv", "lstsq"] = "pinv",
       timer: Literal[True] = True,
       printOut: bool = False,
       printOutEnd: str = "\n",
       timerType: str = "process_time",
       holomorphic: bool = False,
   ) -> tuple[PyTree, int, float]: ...
   def NLLS(
       xiInit: PyTree,
       res: Callable,
       *args: Any,
       constant_arg_nums: list[int] = [],
       J: Optional[Callable[..., np.ndarray]] = None,
       cond: Optional[Callable[[PyTree], bool]] = None,
       body: Optional[Callable[[PyTree], PyTree]] = None,
       tol: float = 1e-13,
       maxIter: uint = 50,
       method: Literal["pinv", "lstsq"] = "pinv",
       timer: bool = False,
       printOut: bool = False,
       printOutEnd: str = "\n",
       timerType: str = "process_time",
       holomorphic: bool = False,
   ) -> tuple[PyTree, int] | tuple[PyTree, int, float]:
       """
       JIT-ed non-linear least squares.
       This function takes in an initial guess, xiInit (initial values of xi), and a residual function, res, and
       performs a nonlinear least squares to minimize the res function using the parameters
       xi. With the default condition function the iteration continues while all of the following hold,
       i.e., it terminates as soon as any one of them fails:
       1. max(abs(res)) > tol
       2. max(abs(dxi)) > tol, where dxi is the change in xi from the last iteration.
       3. Number of iterations < maxIter.

       Parameters
       ----------
       xiInit : pytree or array-like
           Initial guess for the unknown parameters.

       res : function
           Residual function (also known as the loss function) with signature res(xi,*args).

       *args : iterable
           Any additional arguments taken by res other than xi.

       constant_arg_nums: list[int], optional
           These arguments will be removed from the residual function and treated as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details.
           The list is only read; it is never modified by this function.

       J : function
            User specified Jacobian. If None, then the Jacobian of res with respect to xi will be calculated via automatic differentiation. (Default value = None)

       cond : Optional[Callable[[PyTree],bool]]
            User specified condition function. It receives the loop state, a dictionary with keys "xi", "dxi", "it", and "args".
            If None, then the default cond function is used which checks the three termination criteria
            provided in the function description. (Default value = None)

       body : Optional[Callable[[PyTree],PyTree]]
            User specified while-loop body function. It receives and must return the loop state dictionary (keys "xi", "dxi", "it", "args").
            If None, then use the default body function which updates xi using a NLLS iteration and the method provided.
            (Default value = None)

       tol : float
            Tolerance used in the default termination criteria: see function description for more details. (Default value = 1e-13)

       maxIter : int, optional
            Maximum number of iterations. (Default value = 50)

       method : Literal["pinv","lstsq"], optional
            Method for least-squares inversion. (Default value = "pinv")
            * pinv - Use np.linalg.pinv
            * lstsq - Use np.linalg.lstsq

       timer : bool, optional
            Boolean that chooses whether to time the code or not. (Default value = False). Note that setting to true adds a slight increase in runtime,
            as one iteration of the non-linear least squares is run first to avoid timing the JAX trace.

       printOut : bool, optional
            Controls whether the NLLS prints out information each iteration or not. The printout consists of the iteration and max(abs(res)) at each iteration. (Default value = False)

       printOutEnd : str, optional
            Value of keyword argument end passed to the print statement used in printOut. (Default value = "\\\\n")

       timerType : str, optional
            Any timer from the time module. (Default value = "process_time")

       holomorphic : bool, optional
            Indicates whether residual function is promised to be holomorphic. (Default value = False)

       Returns
       -------
       xi : PyTree
            Unknowns that minimize res as found via least-squares. Type will be the same as xiInit specified in the input.

       it : int
            Number of NLLS iterations performed.

       time : float
            Computation time as calculated by timerType specified. This output is only returned if timer = True.
       """

       if timer and printOut:
           TFCPrint.Warning(
               "Warning, you have both the timer and printer on in the nonlinear least-squares.\nThe time will be longer than optimal due to the printout."
           )

       dictFlag = isinstance(xiInit, (TFCDict, TFCDictRobust))

       if cond is None:
           # Only install the default condition when the caller did not supply one.
           # `res` is looked up when cond runs, so it sees the partially evaluated
           # residual if constant_arg_nums is used below.
           def cond(val):
               return np.all(
                   np.array(
                       [
                           np.max(np.abs(res(val["xi"], *val["args"]))) > tol,
                           val["it"] < maxIter,
                           np.max(np.abs(val["dxi"])) > tol,
                       ]
                   )
               )

       if J is None:
           if dictFlag:
               if isinstance(xiInit, TFCDictRobust):

                   def J(xi, *args):
                       # One Jacobian block per key, flattened to 2-D and stacked column-wise.
                       jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                       return np.hstack(
                           [
                               jacob[k].reshape(jacob[k].shape[0], onp.prod(onp.array(xi[k].shape)))
                               for k in xi.keys()
                           ]
                       )

               else:

                   def J(xi, *args):
                       # One Jacobian block per key, stacked column-wise.
                       jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                       return np.hstack([jacob[k] for k in xi.keys()])

           else:
               J = lambda xi, *args: jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)

       # LS returns the step dxi that the default body subtracts from xi.
       if method == "pinv":
           LS = lambda xi, *args: np.dot(np.linalg.pinv(J(xi, *args)), res(xi, *args))
       elif method == "lstsq":
           LS = lambda xi, *args: np.linalg.lstsq(J(xi, *args), res(xi, *args), rcond=None)[0]
       else:
           TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

       if constant_arg_nums:
           # Make arguments constant if desired. LS must be wrapped before res is
           # rebound, since LS (through J) refers to the original res.
           LS = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(LS)
           res = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(res)

           # Drop the now-constant arguments. Indices count xi as argument 0, hence
           # the k - 1 offset into args. Removing from the highest index down keeps
           # the remaining indices valid, and sorted() leaves the caller's list alone.
           args = list(args)
           for k in sorted(constant_arg_nums, reverse=True):
               args.pop(k - 1)

       if body is None:
           if printOut:

               def body(val):
                   val["dxi"] = LS(val["xi"], *val["args"])
                   val["xi"] -= val["dxi"]
                   # Host-side print of the iteration number and residual magnitude.
                   io_callback(
                       partial(nlls_id_print, end=printOutEnd),
                       None,
                       val["it"],
                       np.max(np.abs(res(val["xi"], *val["args"]))),
                   )
                   val["it"] += 1
                   return val

           else:

               def body(val):
                   val["dxi"] = LS(val["xi"], *val["args"])
                   val["xi"] -= val["dxi"]
                   val["it"] += 1
                   return val

       nlls = jit(lambda val: lax.while_loop(cond, body, val))

       # dxi starts at ones so the default "max(abs(dxi)) > tol" test passes on entry.
       if dictFlag:
           dxi = np.ones_like(cast(TFCDict | TFCDictRobust, xiInit).toArray())
       else:
           dxi = np.ones_like(xiInit)

       if timer:
           import time

           timer_f: Callable[[], float] = getattr(time, timerType)

           # Warm-up call starting at maxIter - 1, so the default cond allows at most
           # one iteration; this keeps tracing/compilation out of the timing.
           val = {"xi": xiInit, "dxi": dxi, "it": maxIter - 1, "args": args}
           nlls(val)["dxi"].block_until_ready()

           val = {"xi": xiInit, "dxi": dxi, "it": 0, "args": args}

           start = timer_f()
           val = nlls(val)
           val["dxi"].block_until_ready()
           stop = timer_f()

           return val["xi"], val["it"], stop - start
       else:
           val = {"xi": xiInit, "dxi": dxi, "it": 0, "args": args}
           val = nlls(val)
           return val["xi"], val["it"]
   
   
   class NllsClass:
       """
       JITed nonlinear least squares class.
       Like the :func:`NLLS <tfc.utils.TFCUtils.NLLS>` function, but it is in class form so that the run method can be called multiple times without re-JITing.

       With the default condition function the iteration continues while all of the following hold:
       1. max(abs(res)) > tol
       2. max(abs(dxi)) > tol, where dxi is the change in xi from the last iteration.
       3. Number of iterations < maxIter.
       """

       def __init__(
           self,
           xiInit: PyTree,
           res: Callable,
           *args: Any,
           constant_arg_nums: list[int] = [],
           J: Optional[Callable[..., np.ndarray]] = None,
           cond: Optional[Callable[[PyTree], bool]] = None,
           body: Optional[Callable[[PyTree], PyTree]] = None,
           tol: float = 1e-13,
           maxIter: uint = 50,
           method: Literal["pinv", "lstsq"] = "pinv",
           timer: bool = False,
           printOut: bool = False,
           printOutEnd: str = "\n",
           timerType: str = "process_time",
           holomorphic: bool = False,
       ) -> None:
           """
           Initialization function. Creates the JIT-ed nonlinear least-squares function.

           Parameters
           ----------
           xiInit : pytree or array-like
               Initial guess for the unknown parameters.

           res : function
               Residual function (also known as the loss function) with signature res(xi,*args).

           *args : iterable
               Any additional arguments taken by res other than xi.

           constant_arg_nums: list[int], optional
               These arguments will be removed from the residual function and treated as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details.
               The list is only read; it is never modified by this class.

           J : function
                User specified Jacobian. If None, then the Jacobian of res with respect to xi will be calculated via automatic differentiation. (Default value = None)

           cond : Optional[Callable[[PyTree],bool]]
                User specified condition function. It receives the loop state, a dictionary with keys "xi", "dxi", "it", and "args".
                If None, then the default cond function is used which checks the three termination criteria
                provided in the class description. (Default value = None)

           body : Optional[Callable[[PyTree],PyTree]]
                User specified while-loop body function. It receives and must return the loop state dictionary (keys "xi", "dxi", "it", "args").
                If None, then use the default body function which updates xi using a NLLS iteration and the method provided.
                (Default value = None)

           tol : float
                Tolerance used in the default termination criteria: see class description for more details. (Default value = 1e-13)

           maxIter : int, optional
                Maximum number of iterations. (Default value = 50)

           method : Literal["pinv","lstsq"], optional
                Method for least-squares inversion. (Default value = "pinv")
                * pinv - Use np.linalg.pinv
                * lstsq - Use np.linalg.lstsq

           timer : bool, optional
                Boolean that chooses whether to time the code or not. (Default value = False). Note that setting to true adds a slight increase in runtime,
                as one iteration of the non-linear least squares is run first to avoid timing the JAX trace.

           printOut : bool, optional
                Controls whether the NLLS prints out information each iteration or not. The printout consists of the iteration and max(abs(res)) at each iteration. (Default value = False)

           printOutEnd : str, optional
                Value of keyword argument end passed to the print statement used in printOut. (Default value = "\\\\n")

           timerType : str, optional
                Any timer from the time module. (Default value = "process_time")

           holomorphic : bool, optional
                Indicates whether residual function is promised to be holomorphic. (Default value = False)
           """

           self.timerType = timerType
           self.timer = timer
           self._maxIter = maxIter
           self.holomorphic = holomorphic

           if timer and printOut:
               TFCPrint.Warning(
                   "Warning, you have both the timer and printer on in the nonlinear least-squares.\nThe time will be longer than optimal due to the printout."
               )

           self._dictFlag = isinstance(xiInit, (TFCDict, TFCDictRobust))

           if cond is None:
               # Only install the default condition when the caller did not supply one.
               # `res` is looked up when cond runs, so it sees the partially evaluated
               # residual if constant_arg_nums is used below.
               def cond(val):
                   return np.all(
                       np.array(
                           [
                               np.max(np.abs(res(val["xi"], *val["args"]))) > tol,
                               val["it"] < maxIter,
                               np.max(np.abs(val["dxi"])) > tol,
                           ]
                       )
                   )

           if J is None:
               if self._dictFlag:
                   if isinstance(xiInit, TFCDictRobust):

                       def J(xi, *args):
                           # One Jacobian block per key, flattened to 2-D and stacked column-wise.
                           jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)
                           return np.hstack(
                               [
                                   jacob[k].reshape(
                                       jacob[k].shape[0], onp.prod(onp.array(xi[k].shape))
                                   )
                                   for k in xi.keys()
                               ]
                           )

                   else:

                       def J(xi, *args):
                           # One Jacobian block per key, stacked column-wise.
                           jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)
                           return np.hstack([jacob[k] for k in xi.keys()])

               else:
                   J = lambda xi, *args: jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)

           # LS returns the step dxi that the default body subtracts from xi.
           if method == "pinv":
               LS = lambda xi, *args: np.dot(np.linalg.pinv(J(xi, *args)), res(xi, *args))
           elif method == "lstsq":
               LS = lambda xi, *args: np.linalg.lstsq(J(xi, *args), res(xi, *args), rcond=None)[0]
           else:
               TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

           if constant_arg_nums:
               # Make arguments constant if desired. LS must be wrapped before res is
               # rebound, since LS (through J) refers to the original res. The caller's
               # list is left untouched; run() receives its own arguments, so no reduced
               # argument list is built here.
               LS = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(LS)
               res = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(res)

           if body is None:
               if printOut:

                   def body(val):
                       val["dxi"] = LS(val["xi"], *val["args"])
                       val["xi"] -= val["dxi"]
                       # Host-side print of the iteration number and residual magnitude.
                       io_callback(
                           partial(nlls_id_print, end=printOutEnd),
                           None,
                           val["it"],
                           np.max(np.abs(res(val["xi"], *val["args"]))),
                       )
                       val["it"] += 1
                       return val

               else:

                   def body(val):
                       val["dxi"] = LS(val["xi"], *val["args"])
                       val["xi"] -= val["dxi"]
                       val["it"] += 1
                       return val

           self._nlls = jit(lambda val: lax.while_loop(cond, body, val))

           # Tracks whether the JIT-ed loop has been called at least once, so the timed
           # path can skip the warm-up call.
           self._compiled = False

       def run(self, xiInit: PyTree, *args: Any) -> tuple[PyTree, int] | tuple[PyTree, int, float]:
           """Runs the JIT-ed nonlinear least-squares function and times it if desired.

           Parameters
           ----------
           xiInit : PyTree
               Initial guess for the unknown parameters.

           *args : Any
               Any additional arguments taken by res other than xi.
               NOTE(review): presumably this excludes arguments frozen through constant_arg_nums - confirm against pe/pejit.

           Returns
           -------
           xi : PyTree
                Unknowns that minimize res as found via least-squares. Type will be the same as xiInit specified in the input.

           it : int
                Number of NLLS iterations performed.

           time : float
                Computation time as calculated by timerType specified. This output is only returned if timer = True.

           """

           # dxi starts at ones so the default "max(abs(dxi)) > tol" test passes on entry.
           if self._dictFlag:
               dxi = np.ones_like(cast(TFCDict | TFCDictRobust, xiInit).toArray())
           else:
               dxi = np.ones_like(xiInit)

           if self.timer:
               import time

               timer_f: Callable[[], float] = getattr(time, self.timerType)

               if not self._compiled:
                   # Warm-up call starting at maxIter - 1, so the default cond allows at
                   # most one iteration; this keeps tracing/compilation out of the timing.
                   val = {"xi": xiInit, "dxi": dxi, "it": self._maxIter - 1, "args": args}
                   self._nlls(val)["dxi"].block_until_ready()
                   self._compiled = True

               val = {"xi": xiInit, "dxi": dxi, "it": 0, "args": args}

               start = timer_f()
               val = self._nlls(val)
               val["dxi"].block_until_ready()
               stop = timer_f()

               return val["xi"], val["it"], stop - start

           else:
               val = {"xi": xiInit, "dxi": dxi, "it": 0, "args": args}
               val = self._nlls(val)

               self._compiled = True

               return val["xi"], val["it"]
   
   
   class ComponentConstraintDict(TypedDict):
       """
       Dictionary describing one component constraint, used as an edge by
       :class:`ComponentConstraintGraph`.
       """

       # Name of the component constraint (used as the edge label when graphs are saved).
       name: str
       # Name of one of the two nodes joined by the constraint.
       node0: str
       # Name of the other node joined by the constraint.
       node1: str
   
   
   class ComponentConstraintGraph:
       """
       Creates a graph of all valid ways in which component constraints can be embedded.

       Every edge (component constraint) can be oriented in two ways. The constructor
       enumerates all orientations and records those that pass the validity check; the
       results can be written out with :meth:`SaveGraphs`.
       """

       def __init__(self, N: list[str], E: list[ComponentConstraintDict]) -> None:
           """
           Class constructor.

           Parameters
           ----------
           N : list[str]
               A list of strings that specify the node names. These node names typically coincide with
               the names of the dependent variables.
           E : list[ComponentConstraintDict]
               The ComponentConstraintDict is a dictionary with the following fields:
               * name - Name of the component constraint.
               * node0 - The name of one of the nodes that makes up the component constraint.  Must correspond with an element of the list given in N.
               * node1 - The name of one of the nodes that makes up the component constraint.  Must correspond with an element of the list given in N.
           """

           # Check that all edges are connected to valid nodes
           self.nNodes = len(N)
           self.nEdges = len(E)
           for k in range(self.nEdges):
               if not (E[k]["node0"] in N and E[k]["node1"] in N):
                   TFCPrint.Error(
                       "Error either "
                       + E[k]["node0"]
                       + " or "
                       + E[k]["node1"]
                       + " is not a valid node. Make sure they appear in the nodes list."
                   )

           # Create all possible source/target pairs. This tells whether node0 is the target or source, node1 will be the opposite.
           # Each entry is a tuple with one 0/1 flag per edge, giving 2**nEdges combinations.
           import itertools

           self.targets = list(itertools.product([0, 1], repeat=self.nEdges))

           # Find all targets that are valid trees
           self.goodTargets = []
           for j in range(len(self.targets)):
               # Directed adjacency matrix for orientation j: adj[source, target] = 1.
               adj = onp.zeros((self.nNodes, self.nNodes))
               for k in range(self.nNodes):
                   kNode = N[k]
                   for g in range(self.nEdges):
                       if E[g]["node0"] == kNode:
                           # Flag set: node0 is the target, so the edge runs node1 -> node0.
                           if self.targets[j][g]:
                               adj[N.index(E[g]["node1"]), N.index(E[g]["node0"])] = 1.0
                       elif E[g]["node1"] == kNode:
                           # Flag clear: node1 is the target, so the edge runs node0 -> node1.
                           if not self.targets[j][g]:
                               adj[N.index(E[g]["node0"]), N.index(E[g]["node1"])] = 1.0
               # Keep the orientation only if every eigenvalue of the adjacency matrix is
               # exactly zero. NOTE(review): presumably this is an acyclicity test (a
               # nilpotent adjacency matrix has no directed cycles) - confirm intent.
               if np.all(np.linalg.eigvals(adj) == 0.0):
                   self.goodTargets.append(j)

           # Save nodes and edges for use later
           self.N = N
           self.E = E

       def SaveGraphs(self, outputDir: Path, allGraphs: bool = False, savePDFs: bool = False) -> None:
           """
           Saves the graphs.
           The graphs are saved in a clickable HTML structure. They can also be saved as PDFs.

           Parameters
           ----------
           outputDir : Path
               Output directory to save in.

           allGraphs : bool, optional
                Boolean that controls whether all graphs are saved or just valid graphs. (Default value = False)

           savePDFs : bool, optional
                Boolean that controls whether the graphs are also saved as PDFs. (Default value = False)
           """
           import os
           from .Html import HTML, Dot

           # Select either every orientation or only those recorded as valid in __init__.
           if allGraphs:
               targets = self.targets
           else:
               targets = [self.targets[k] for k in self.goodTargets]

           n = len(targets)

           #: Create the main dot file
           mainDot = Dot(os.path.join(outputDir, "dotFiles", "main"), "main")
           mainDot.dot.node_attr.update(shape="box")
           mainDot.dot.edge_attr.update(style="invis")
           treeCnt = 0
           # Index nodes are laid out five per same-rank subgraph.
           for j in range(int(np.ceil(n / 5))):
               if j != 0:
                   # Invisible edge between the first nodes of consecutive groups.
                   # NOTE(review): presumably this forces the groups to stack vertically - confirm.
                   mainDot.dot.edge("tree" + str((j - 1) * 5), "tree" + str(j * 5))
               with mainDot.dot.subgraph(name="subgraph" + str(j)) as c:
                   c.attr(rank="same")
                   for k in range(min(5, n - j * 5)):
                       # Each index node links to the HTML page of the corresponding tree.
                       c.node(
                           "tree" + str(treeCnt),
                           "Tree " + str(treeCnt),
                           href=os.path.join("htmlFiles", "tree" + str(treeCnt) + ".html"),
                       )
                       treeCnt += 1

           mainDot.Render()

           #: Create the main file HTML
           # The page embeds the rendered main.svg together with the contents of main.cmapx,
           # so the main dot file must be rendered before this point.
           mainHtml = HTML(os.path.join(outputDir, "main.html"))
           with mainHtml.tag("html"):
               with mainHtml.tag("body"):
                   with mainHtml.tag("style"):
                       mainHtml.doc.asis(mainHtml.centerClass)
                   mainHtml.doc.stag(
                       "img", src=os.path.join("dotFiles", "main.svg"), usemap="#main", klass="center"
                   )
                   mainHtml.doc.asis(
                       mainHtml.ReadFile(os.path.join(outputDir, "dotFiles", "main.cmapx"))
                   )
           mainHtml.WriteFile()

           #: Create the tree dot files
           for k in range(n):
               treeDot = Dot(os.path.join(outputDir, "dotFiles", "tree" + str(k)), "tree" + str(k))
               treeDot.dot.attr(bgcolor="transparent")
               treeDot.dot.node_attr.update(shape="box")
               for j in range(self.nNodes):
                   treeDot.dot.node(self.N[j], self.N[j])
               for j in range(self.nEdges):
                   # Edge direction follows the same flag convention as in __init__:
                   # flag clear -> node0 to node1, flag set -> node1 to node0.
                   if not targets[k][j]:
                       treeDot.dot.edge(
                           self.E[j]["node0"], self.E[j]["node1"], label=self.E[j]["name"]
                       )
                   else:
                       treeDot.dot.edge(
                           self.E[j]["node1"], self.E[j]["node0"], label=self.E[j]["name"]
                       )

               if savePDFs:
                   treeDot.Render(formats=["cmapx", "svg", "pdf"])
               else:
                   treeDot.Render()

           #: Create the tree HTML files
           # Each page embeds the tree's svg (one directory up from htmlFiles) and the
           # contents of its cmapx file, both produced by the loop above.
           for k in range(n):
               treeHtml = HTML(os.path.join(outputDir, "htmlFiles", "tree" + str(k) + ".html"))
               with treeHtml.tag("html"):
                   with treeHtml.tag("body"):
                       with treeHtml.tag("style"):
                           treeHtml.doc.asis(treeHtml.centerClass)
                       treeHtml.doc.stag(
                           "img",
                           src=os.path.join("..", "dotFiles", "tree" + str(k) + ".svg"),
                           usemap="#tree" + str(k),
                           klass="center",
                       )
                       treeHtml.doc.asis(
                           treeHtml.ReadFile(
                               os.path.join(outputDir, "dotFiles", "tree" + str(k) + ".cmapx")
                           )
                       )
               treeHtml.WriteFile()
   
   
   def ScaledQrLs(A: np.ndarray, B: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
       """This function performs least-squares using a scaled QR method.

       The columns of A are scaled to unit Euclidean norm before the QR
       factorization, and the scaling is undone on the solution afterwards.

       Parameters
       ----------
       A : np.ndarray
           A matrix in A*x = B.

       B : np.ndarray
           B matrix in A*x = B. May be a vector or a matrix with one column per right-hand side.

       Returns
       -------
       x : np.ndarray
           Solution to A*x = B solved using a scaled QR method.

       cn : np.ndarray
           Condition number of the R factor of the scaled matrix.
       """
       # Reciprocal of the Euclidean norm of each column of A.
       S = 1.0 / np.sqrt(np.sum(A * A, 0))
       S = np.reshape(S, (A.shape[1],))
       q, r = np.linalg.qr(A.dot(np.diag(S)))

       # Solution of the scaled problem (A*diag(S)) * y = B.
       y = np.linalg.multi_dot([_MatPinv(r), q.T, B])

       # Undo the scaling: x = diag(S) * y. S must multiply along the first axis of y,
       # so it is reshaped to broadcast over any trailing right-hand-side axis instead
       # of being matched against the last axis.
       x: np.ndarray = np.reshape(S, (-1,) + (1,) * (y.ndim - 1)) * y
       cn: np.ndarray = cast(np.ndarray, np.linalg.cond(r))
       return x, cn
   
   
   def _MatPinv(A: np.ndarray) -> np.ndarray:
       """Pseudo-inverse intended to better replicate MATLAB's pseudo-inverse.

       The cutoff handed to np.linalg.pinv is the largest dimension of A times
       the floating-point spacing at the 2-norm of A.

       Parameters
       ----------
       A : np.ndarray
           Matrix to be inverted.

       Returns
       -------
       Ainv : np.ndarray
           Pseudo-inverse of A.
       """
       largest_dim = onp.max(A.shape)
       spectral_norm = np.linalg.norm(A, ord=2)
       cutoff = largest_dim * onp.spacing(spectral_norm)
       return np.linalg.pinv(A, rcond=cutoff)
   
   
   def step(x: np.ndarray) -> np.ndarray:
       """
       This is the unit step function, but the derivative is defined and equal to 0 at every point.

       Implemented with np.heaviside, using 0 as the value returned where x == 0.

       Parameters
       ----------
       x : np.ndarray
           Array to apply step to.


       Returns
       -------
       step_x : np.ndarray
           step(x)
       """
       value_at_zero = 0
       step_x = np.heaviside(x, value_at_zero)
       return step_x
   
   
   def criuCheckpoint(dir: str = "criu_checkpoint", user_mode: bool = False):
       """
       Use CRIU to create a checkpoint of your program.
       WARNING: You must ensure no external sockets are used, i.e., via matplotlib.
       WARNING: GPU memory cannot be mapped for all GPUs.
       WARNING: user_mode is a work in progress and is not full featured yet.

       The dump is launched as a detached background shell job against the current
       process ID, and the command needed to restore the checkpoint is printed.

       Parameters
       ----------
       dir : str
           Directory where the CRIU checkpoint will be created. It is created, including
           any missing parent directories, if it does not exist.
       user_mode : bool
           Whether to use user mode or not. When True, criu is run without sudo and with
           the --unprivileged option. (Default value = False)
       """

       import os
       import shlex
       from pathlib import Path

       path = Path(dir)
       pid = os.getpid()

       # Create path (and any missing parents) if it does not exist
       path.mkdir(parents=True, exist_ok=True)

       image_dir = str(path.absolute())
       # Quote the directory so paths containing spaces or shell metacharacters reach
       # criu as a single argument.
       quoted_dir = shlex.quote(image_dir)

       # Run the command and print how to restart. The dump runs in a detached
       # background subshell so that os.system returns immediately.
       if user_mode:
           os.system(
               f"(criu dump --unprivileged -t {pid} -D {quoted_dir} --shell-job --leave-running &) &"
           )
           print(
               f"Creating a checkpoint. To restart run: sudo criu restore -D {image_dir} --shell-job"
           )
       else:
           os.system(
               f"(sudo criu dump -t {pid} -D {quoted_dir} --shell-job --leave-running &) &"
           )
           print(
               f"Creating a checkpoint. To restart run: sudo criu restore -D {image_dir} --shell-job"
           )
