Source code for tfc.utils.TFCUtils

import sys
from colorama import init as initColorama
from colorama import Fore as fg
from colorama import Style as style

from collections import OrderedDict
from functools import partial

from jax._src.config import config

config.update("jax_enable_x64", True)
import numpy as onp
import numpy.typing as npt
import jax.numpy as np
from jax import jvp, jit, lax, jacfwd, typeof
from jax.extend import linear_util as lu
from jax.api_util import debug_info
from jax.tree_util import register_pytree_node, tree_map
from jax._src.api_util import flatten_fun
from jax._src.tree_util import tree_flatten
from jax.core import eval_jaxpr
from jax.interpreters.partial_eval import trace_to_jaxpr_nounits, PartialVal
from jax.experimental import io_callback
from typing import Any, Callable, Optional, cast, Union, overload
from .tfc_types import uint, Literal, TypedDict, Path
from jaxtyping import PyTree

# Types that can be added to a TFCDict: an array (np is jax.numpy in this module),
# a plain dictionary, or another TFCDict. "TFCDict" is written as a string because
# the class is defined further down in the file.
TFCDictAddable = Union[np.ndarray, dict[Any, Any], "TFCDict"]

# Types that can be added to a TFCDictRobust: an array (np is jax.numpy in this
# module), a plain dictionary, or another TFCDictRobust. "TFCDictRobust" is written
# as a string because the class is defined further down in the file.
TFCDictRobustAddable = Union[np.ndarray, dict[Any, Any], "TFCDictRobust"]


class TFCPrint:
    """
    This class is used to print to the terminal in color.
    """

    def __init__(self):
        """
        This function is the constructor. It initializes the colorama class.
        """
        initColorama()

    @staticmethod
    def Error(stringIn: str):
        """
        This function prints errors. It prints the text in 'stringIn' in bright red
        and exits the program with a non-zero exit status.

        Parameters
        ----------
        stringIn : str
            error string

        Raises
        ------
        SystemExit
            Always raised (via sys.exit) after the message has been printed.
        """
        print(fg.RED + style.BRIGHT + stringIn)
        # Reset the terminal style before exiting so the color does not leak into
        # whatever is printed after the program terminates.
        print(style.RESET_ALL, end="")
        # Exit status 1 signals failure to the calling shell/process; a bare
        # sys.exit() would report success (status 0) even though an error occurred.
        sys.exit(1)

    @staticmethod
    def Warning(stringIn: str) -> None:
        """
        This function prints warnings. It prints the text in 'stringIn' in bright yellow.

        Parameters
        ----------
        stringIn : str
            warning string
        """
        print(fg.YELLOW + style.BRIGHT + stringIn)
        print(style.RESET_ALL, end="")
def egrad(g: Callable[..., Any], j: uint = 0):
    """
    This function mimics egrad from autograd.

    The derivative is computed with a single forward-mode pass (jvp) that uses a
    tangent of ones for argument j and tangents of zeros for every other argument.

    Parameters
    ----------
    g : Callable[..., Any]
        Function to take the derivative of.

    j : uint, optional
        Parameter with which to take the derivative with respect to. (Default value = 0)

    Returns
    -------
    wrapped : function
        Derivative function
    """

    def wrapped(*args: Any) -> Any:
        """
        Wrapper for derivative of g with respect to parameter number j.

        Parameters
        ----------
        *args : Any
            function arguments to g. Each argument must expose ``shape`` and ``dtype``.

        Returns
        -------
        x_bar: Any
            derivative of g with respect to parameter number j
        """
        # Build the tangents with the dtype of the matching argument. jvp requires
        # primal and tangent dtypes to agree, so a default float64 tangent is
        # rejected for, e.g., float32 or complex arguments when x64 is enabled.
        tans = tuple(
            onp.ones(arg.shape, dtype=arg.dtype)
            if i == j
            else onp.zeros(arg.shape, dtype=arg.dtype)
            for i, arg in enumerate(args)
        )
        _, x_bar = jvp(g, args, tans)
        return x_bar

    return wrapped
@partial(partial, tree_map)
def onesRobust(val: PyTree):
    """
    Build a pytree of ones that mirrors val. Works for arrays and dictionaries.

    Parameters
    ----------
    val : PyTree

    Returns
    -------
    ones_like_val : PyTree
        Pytree with the same structure as val with all elements equal to one.
    """
    # The decorator turns this function into partial(tree_map, <this function>), so
    # the body only ever receives a single leaf. The cast is purely for the benefit
    # of LSPs/type checkers, which otherwise see a PyTree here.
    leaf = cast(npt.NDArray, val)
    return onp.ones(leaf.shape, dtype=leaf.dtype)


@partial(partial, tree_map)
def zerosRobust(val: PyTree):
    """
    Build a pytree of zeros that mirrors val. Works for arrays and dictionaries.

    Parameters
    ----------
    val : PyTree

    Returns
    -------
    zeros_like_val : PyTree
        Pytree with the same structure as val with all elements equal to zero.
    """
    # The decorator turns this function into partial(tree_map, <this function>), so
    # the body only ever receives a single leaf. The cast is purely for the benefit
    # of LSPs/type checkers, which otherwise see a PyTree here.
    leaf = cast(npt.NDArray, val)
    return onp.zeros(leaf.shape, dtype=leaf.dtype)
def egradRobust(g: Callable[..., Any], j: uint = 0):
    """This function mimics egrad from autograd, but can also handle dictionaries.

    The derivative is computed with a single forward-mode pass (jvp) that uses a
    pytree of ones as the tangent for argument j and pytrees of zeros for every
    other argument.

    Parameters
    ----------
    g : function
        Function to take the derivative of.

    j : integer, optional
        Parameter with which to take the derivative with respect to. (Default value = 0)

    Returns
    -------
    wrapped : function
        Derivative function
    """
    # If g carries the qualified name "jit.<locals>.f_jitted", differentiate the
    # function it wraps instead (presumably how some JAX versions name jitted
    # functions - TODO confirm). getattr is used because not every callable has a
    # __qualname__, e.g. functools.partial objects.
    if getattr(g, "__qualname__", None) == "jit.<locals>.f_jitted":
        g = g.__wrapped__

    def wrapped(*args: Any) -> Any:
        """
        Wrapper for derivative of g with respect to parameter number j.

        Parameters
        ----------
        *args : iterable
            function arguments to g

        Returns
        -------
        x_bar: array-like
            derivative of g with respect to parameter number j
        """
        tans = tuple(
            onesRobust(arg) if i == j else zerosRobust(arg) for i, arg in enumerate(args)
        )
        _, x_bar = jvp(g, args, tans)
        return x_bar

    return wrapped
[docs] def pe(*args: Any, constant_arg_nums: list[int] = []) -> Any: """ Decorator that returns a function evaluated such that the arg numbers specified in constant_arg_nums and all functions that utilizes only those arguments are treated as compile time constants. Parameters ---------- *args : Any Arguments for the function that pe is applied to. constant_arg_nums : list[int], optional The arguments whose values and functions that depend only on these values should be treated as cached constants. Returns ------- f : Any The new function whose constant_arg_num arguments have been removed. The jaxpr of this function has the constant_arg_num values and all functions that depend on those values cached as constants. Usage ----- @pe(*args, constant_arg_nums=[0]) def f(x,xi): # Function stuff here # Returns an f(xi) with x treated as constant """ # Reorder to put knowns first, then unknowns order = [k for k in range(len(args))] for k in constant_arg_nums: order.insert(0, order.pop(k)) reorder = np.argsort(np.array(order)) dark = tuple(args[k] for k in order) # Store the removed args for later num_args_remove = len(constant_arg_nums) def wrapper(f_orig): if len(constant_arg_nums) > 0: # Reordering args so the ones to remove are given first # This will allow us to return a function that has completely removed those args # Moreover, we do it here so this reordering will be optimized by the compiler def f(*args): new_args = tuple(args[k] for k in reorder) return f_orig(*new_args) # Create the partial args needed by trace_to_jaxpr_nounits def get_arg(a, unknown): if unknown: return tree_flatten( ( tree_map(lambda x: PartialVal.unknown(typeof(x).at_least_vspace()), a), {}, ) )[0] else: return PartialVal.known(a) part_args = [] for k, a in enumerate(dark): temp = get_arg(a, k >= num_args_remove) if isinstance(temp, list): part_args += temp else: part_args.append(temp) part_args = tuple(part_args) # Create jaxpr wrap = lu.wrap_init(f, debug_info=debug_info(f.__name__, f, [], 
{})) _, in_tree = tree_flatten((dark, {})) wrap_flat, out_tree = flatten_fun(wrap, in_tree) jaxpr, _, const = trace_to_jaxpr_nounits(wrap_flat, part_args) # Create new, partially evaluated function if out_tree().num_leaves == 1 and out_tree().num_nodes == 1: # out_tree() is PyTreeDef(*), so just return the value. Since eval_jaxpr returns a list, # this is just value [0] f_removed = lambda *args: eval_jaxpr(jaxpr, const, *tree_flatten((*args, {}))[0])[0] else: # Use out_tree() to reshape the args correctly. f_removed = lambda *args: out_tree().unflatten( eval_jaxpr(jaxpr, const, *tree_flatten((*args, {}))[0]) ) return f_removed else: return f_orig return wrapper
def pejit(*args: Any, constant_arg_nums: list[int] = [], **kwargs) -> Any:
    """
    Partially evaluate a function with :func:`pe <tfc.utils.TFCUtils.pe>` and then JIT
    the result. See :func:`pe <tfc.utils.TFCUtils.pe>` for more details.

    Parameters:
    -----------
    *args: Any
        Arguments for the function that pe is applied to.

    constant_arg_nums: list[int], optional
        The arguments whose values and functions that depend only on these values
        should be treated as cached constants.

    **kwargs: Any
        Keyword arguments passed on to JIT.

    Returns:
    --------
    f: Any
        The new function whose constant_arg_num arguments have been removed. The jaxpr
        of this function has the constant_arg_num values and all functions that depend
        on those values cached, and they are treated as compile time constants.
    """

    def wrap(f_orig):
        # First strip out the constant arguments, then hand the reduced function to jit.
        partially_evaluated = pe(*args, constant_arg_nums=constant_arg_nums)(f_orig)
        return jit(partially_evaluated, **kwargs)

    return wrap
class TFCDict(OrderedDict):
    """
    This is the TFC dictionary class. It extends an OrderedDict and adds a few methods
    that enable:

    - Adding dictionaries with the same keys together
    - Turning a dictionary into a 1-D array
    - Turning a 1-D array into a dictionary
    """

    def __init__(self, *args: Any):
        """
        Initialize TFCDict using the OrderedDict method.

        Parameters
        ----------
        *args : Any
            Arguments to pass to the OrderedDict
        """
        # Store dictionary and keep a record of the keys. Keys will stay in same
        # order, so that adding and subtracting is repeatable.
        super().__init__(*args)
        self._keys = list(self.keys())
        self._nKeys = len(self._keys)
        self.getSlices()

    def getSlices(self) -> None:
        """
        Function that creates slices for each of the keys in the dictionary.

        When every value is a JAX or NumPy array, self._slices[k] is the slice of the
        flat array (see toArray) that belongs to key number k, based on the length of
        the first axis of each value. Otherwise every slice is set to None.
        """
        values_are_arrays = all(
            isinstance(value, (np.ndarray, onp.ndarray)) for value in self.values()
        )
        if not values_are_arrays:
            self._slices = [None] * self._nKeys
            return

        slices = []
        start = 0
        for key in self._keys:
            stop = start + self[key].shape[0]
            slices.append(slice(start, stop, 1))
            start = stop
        self._slices = slices

    def update(self, *args: Any) -> None:
        """
        Overload the update method to update the _keys variable as well.

        Parameters
        ----------
        *args : Any
            Same as *args for the update method of ordered dict.
        """
        super().update(*args)
        self._keys = list(self.keys())
        self._nKeys = len(self._keys)
        self.getSlices()

    def toArray(self) -> np.ndarray:
        """
        Send dictionary to a flat JAX array.

        Returns
        -------
        np.ndarray
            This dictionary as a flat JAX array.
        """
        return cast(np.ndarray, np.hstack(list(self.values())))

    def toDict(self, arr: np.ndarray) -> "TFCDict":
        """
        Send a flat JAX array to a TFCDict with the same keys.

        Parameters
        ----------
        arr : ndarray
            Flat JAX array to convert to TFCDict. Must have the same number of elements
            as total number of elements in the dictionary.

        Returns
        -------
        TFCDict
            JAX array as a TFCDict
        """
        arr = arr.flatten()
        return TFCDict(zip(self._keys, [arr[sl] for sl in self._slices]))

    def block_until_ready(self) -> "TFCDict":
        """
        Mimics block_until_ready for jax arrays. Used to halt the program until the
        computation that created the dictionary is finished.

        Every value that offers a block_until_ready method is waited on; values
        without one (e.g. NumPy arrays) are skipped. An empty dictionary returns
        immediately.

        Returns
        -------
        TFCDict
            This TFCDict.
        """
        for value in self.values():
            wait = getattr(value, "block_until_ready", None)
            if wait is not None:
                wait()
        return self

    def __iadd__(self, o: TFCDictAddable) -> "TFCDict":
        """
        Used to overload "+=" for TFCDict so that 2 TFCDict's can be added together.

        Parameters
        ----------
        o : TFCDictAddable
            Values to add to the current dictionary. A dictionary is added key by key;
            an array is flattened and added slice by slice. Any other type leaves the
            dictionary unchanged.

        Returns
        ----------
        self : TFCDict
            This TFCDict (modified in place) after adding in the values from o.
        """
        if isinstance(o, dict):
            for key in self._keys:
                self[key] += o[key]
        elif isinstance(o, (np.ndarray, onp.ndarray)):
            o = o.flatten()
            for key, sl in zip(self._keys, self._slices):
                self[key] += o[sl]
        return self

    def __isub__(self, o: TFCDictAddable) -> "TFCDict":
        """
        Used to overload "-=" for TFCDict so that 2 TFCDict's can be subtracted.

        Parameters
        ----------
        o : TFCDictAddable
            Values to subtract from the current dictionary. A dictionary is subtracted
            key by key; an array is flattened and subtracted slice by slice. Any other
            type leaves the dictionary unchanged.

        Returns
        -------
        self : TFCDict
            This TFCDict (modified in place) after subtracting the values from o.
        """
        if isinstance(o, dict):
            for key in self._keys:
                self[key] -= o[key]
        elif isinstance(o, (np.ndarray, onp.ndarray)):
            o = o.flatten()
            for key, sl in zip(self._keys, self._slices):
                self[key] -= o[sl]
        return self

    def __add__(self, o: TFCDictAddable) -> "TFCDict":
        """
        Used to overload "+" for TFCDict so that 2 TFCDict's can be added together.

        Parameters
        ----------
        o : TFCDictAddable
            Values to add to the current dictionary.

        Returns
        ----------
        out : TFCDict
            A new TFCDict with values = self + o. self is left untouched.
        """
        # TFCDict(self) is a shallow copy, so out and self share their value objects.
        # Use "a = a + b" instead of "a += b": the in-place form would modify NumPy
        # arrays that self still refers to.
        out = TFCDict(self)
        if isinstance(o, dict):
            for key in self._keys:
                out[key] = out[key] + o[key]
        elif isinstance(o, (np.ndarray, onp.ndarray)):
            o = o.flatten()
            for key, sl in zip(self._keys, self._slices):
                out[key] = out[key] + o[sl]
        return out

    def __sub__(self, o: TFCDictAddable) -> "TFCDict":
        """
        Used to overload "-" for TFCDict so that 2 TFCDict's can be subtracted.

        Parameters
        ----------
        o : TFCDictAddable
            Values to subtract from the current dictionary.

        Returns
        ----------
        out : TFCDict
            A new TFCDict with values = self - o. self is left untouched.
        """
        # See __add__: avoid in-place operators on values shared with self.
        out = TFCDict(self)
        if isinstance(o, dict):
            for key in self._keys:
                out[key] = out[key] - o[key]
        elif isinstance(o, (np.ndarray, onp.ndarray)):
            o = o.flatten()
            for key, sl in zip(self._keys, self._slices):
                out[key] = out[key] - o[sl]
        return out
# Register TFCDict as a JAX type
def _tfcdict_flatten(x):
    """Split a TFCDict into its values (pytree children) and keys (auxiliary data)."""
    return (list(x.values()), list(x.keys()))


def _tfcdict_unflatten(keys, values):
    """Rebuild a TFCDict from the keys and values produced by _tfcdict_flatten."""
    return TFCDict(zip(keys, values))


register_pytree_node(TFCDict, _tfcdict_flatten, _tfcdict_unflatten)
class TFCDictRobust(OrderedDict):
    """This class is like the :class:`TFCDict <tfc.utils.TFCUtils.TFCDict>` class, but it handles non-flat arrays."""

    def __init__(self, *args: Any):
        """
        Initialize TFCDictRobust using the OrderedDict method.

        Parameters
        ----------
        *args : Any
            Arguments to pass to the OrderedDict
        """
        # Store dictionary and keep a record of the keys. Keys will stay in same
        # order, so that adding and subtracting is repeatable.
        super().__init__(*args)
        self._keys = list(self.keys())
        self._nKeys = len(self._keys)
        self.getSlices()

    def getSlices(self) -> None:
        """
        Function that creates slices for each of the keys in the dictionary.

        When every value is a JAX or NumPy array, self._slices[k] is the slice of the
        flat array (see toArray) that belongs to key number k, based on the total
        number of elements of each value. Otherwise every slice is set to None.
        """
        values_are_arrays = all(
            isinstance(value, (np.ndarray, onp.ndarray)) for value in self.values()
        )
        if not values_are_arrays:
            self._slices = [None] * self._nKeys
            return

        slices = []
        start = 0
        for key in self._keys:
            # size is the number of elements of the flattened value.
            stop = start + self[key].size
            slices.append(slice(start, stop, 1))
            start = stop
        self._slices = slices

    def update(self, *args: Any) -> None:
        """
        Overload the update method to update the _keys variable as well.

        Parameters
        ----------
        *args : Any
            Same as *args for the update method on ordered dict.
        """
        super().update(*args)
        self._keys = list(self.keys())
        self._nKeys = len(self._keys)
        self.getSlices()

    def toArray(self) -> np.ndarray:
        """
        Send dictionary to a flat JAX array.

        Returns
        -------
        np.ndarray
            This dictionary as a flat JAX array.
        """
        return cast(np.ndarray, np.hstack([value.flatten() for value in self.values()]))

    def toDict(self, arr: np.ndarray) -> "TFCDictRobust":
        """
        Send a flat JAX array to a TFCDictRobust with the same keys.

        Parameters
        ----------
        arr : np.ndarray
            Flat JAX array to convert to TFCDictRobust. Must have the same number of
            elements as total number of elements in the dictionary.

        Returns
        -------
        TFCDictRobust
            JAX array as a TFCDictRobust. Each value is reshaped to the shape of the
            value currently stored under the same key.
        """
        arr = arr.flatten()
        return TFCDictRobust(
            zip(
                self._keys,
                [arr[sl].reshape(self[key].shape) for key, sl in zip(self._keys, self._slices)],
            )
        )

    def block_until_ready(self) -> "TFCDictRobust":
        """
        Mimics block_until_ready for jax arrays. Used to halt the program until the
        computation that created the dictionary is finished.

        Every value that offers a block_until_ready method is waited on; values
        without one (e.g. NumPy arrays) are skipped. An empty dictionary returns
        immediately.

        Returns
        -------
        TFCDictRobust
            This TFCDictRobust
        """
        for value in self.values():
            wait = getattr(value, "block_until_ready", None)
            if wait is not None:
                wait()
        return self

    def __iadd__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
        """
        Used to overload "+=" for TFCDictRobust so that 2 TFCDictRobust's can be added together.

        Parameters
        ----------
        o : TFCDictRobustAddable
            Values to add to the current dictionary. A dictionary is added key by key;
            an array is flattened and added slice by slice after reshaping each slice
            to the shape of the matching value. Any other type leaves the dictionary
            unchanged.

        Returns
        ----------
        self : TFCDictRobust
            This TFCDictRobust (modified in place) after adding in the values from o.
        """
        if isinstance(o, dict):
            for key in self._keys:
                self[key] += o[key]
        elif isinstance(o, (np.ndarray, onp.ndarray)):
            o = o.flatten()
            for key, sl in zip(self._keys, self._slices):
                self[key] += o[sl].reshape(self[key].shape)
        return self

    def __isub__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
        """
        Used to overload "-=" for TFCDictRobust so that 2 TFCDictRobust's can be subtracted.

        Parameters
        ----------
        o : TFCDictRobustAddable
            Values to subtract from the current dictionary. A dictionary is subtracted
            key by key; an array is flattened and subtracted slice by slice after
            reshaping each slice to the shape of the matching value. Any other type
            leaves the dictionary unchanged.

        Returns
        ----------
        self : TFCDictRobust
            This TFCDictRobust (modified in place) after subtracting the values from o.
        """
        if isinstance(o, dict):
            for key in self._keys:
                self[key] -= o[key]
        elif isinstance(o, (np.ndarray, onp.ndarray)):
            o = o.flatten()
            for key, sl in zip(self._keys, self._slices):
                self[key] -= o[sl].reshape(self[key].shape)
        return self

    def __add__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
        """
        Used to overload "+" for TFCDictRobust so that 2 TFCDictRobust's can be added together.

        Parameters
        ----------
        o : TFCDictRobustAddable
            Values to add to the current dictionary.

        Returns
        ----------
        out : TFCDictRobust
            A new TFCDictRobust with values = self + o. self is left untouched.
        """
        # TFCDictRobust(self) is a shallow copy, so out and self share their value
        # objects. Use "a = a + b" instead of "a += b": the in-place form would
        # modify NumPy arrays that self still refers to.
        out = TFCDictRobust(self)
        if isinstance(o, dict):
            for key in self._keys:
                out[key] = out[key] + o[key]
        elif isinstance(o, (np.ndarray, onp.ndarray)):
            o = o.flatten()
            for key, sl in zip(self._keys, self._slices):
                out[key] = out[key] + o[sl].reshape(self[key].shape)
        return out

    def __sub__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
        """
        Used to overload "-" for TFCDictRobust so that 2 TFCDictRobust's can be subtracted.

        Parameters
        ----------
        o : TFCDictRobustAddable
            Values to subtract from the current dictionary.

        Returns
        -------
        out : TFCDictRobust
            A new TFCDictRobust with values = self - o. self is left untouched.
        """
        # See __add__: avoid in-place operators on values shared with self.
        out = TFCDictRobust(self)
        if isinstance(o, dict):
            for key in self._keys:
                out[key] = out[key] - o[key]
        elif isinstance(o, (np.ndarray, onp.ndarray)):
            o = o.flatten()
            for key, sl in zip(self._keys, self._slices):
                out[key] = out[key] - o[sl].reshape(self[key].shape)
        return out
# Register TFCDictRobust as a JAX type
def _tfcdictrobust_flatten(x):
    """Split a TFCDictRobust into its values (pytree children) and keys (auxiliary data)."""
    return (list(x.values()), list(x.keys()))


def _tfcdictrobust_unflatten(keys, values):
    """Rebuild a TFCDictRobust from the keys and values produced by _tfcdictrobust_flatten."""
    return TFCDictRobust(zip(keys, values))


register_pytree_node(TFCDictRobust, _tfcdictrobust_flatten, _tfcdictrobust_unflatten)


# Typing overload of LS for timer=False: only the solution is returned.
@overload
def LS(
    zXi: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: Literal[False] = False,
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> PyTree: ...


# Typing overload of LS for timer=True: the solution and the elapsed time are returned.
@overload
def LS(
    zXi: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: Literal[True] = True,
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> tuple[PyTree, float]: ...
def LS(
    zXi: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: bool = False,
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> PyTree | tuple[PyTree, float]:
    """
    JITed least squares.
    This function takes in an initial guess of zeros, zXi, and a residual function, res, and
    uses linear least squares to minimize the res function using the parameters xi.

    Parameters
    ----------
    zXi : PyTree
        Unknown parameters to be found using least-squares. Only its structure is
        used: the values are replaced by zeros before solving.

    res : Callable
        Residual function (also known as the loss function) with signature
        res(xi: PyTree, *args:Any, **kwargs:Any). Note, the first argument does not
        need to be named xi, this is just illustrative.

    *args : Any
        Any additional arguments taken by res other than the first PyTree argument.

    constant_arg_nums: list[int], optional
        These arguments will be removed from the residual function and treated as
        constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details.
        The list is not modified.

    J : Optional[Callable[...,np.ndarray]]
        User specified Jacobian function. If None, then the Jacobian of res with
        respect to xi will be calculated via automatic differentiation.
        (Default value = None)

    method : Literal["pinv","lstsq"], optional
        Method for least-squares inversion. (Default value = "pinv")
        * pinv - Use np.linalg.pinv
        * lstsq - Use np.linalg.lstsq

    timer : bool, optional
        Boolean that chooses whether to time the code or not. (Default value = False).
        Note that setting to true adds a slight increase in runtime, as the
        least-squares function is run once before timing to avoid timing the JAX trace.

    timerType : str, optional
        Any timer from the time module. (Default value = "process_time")

    holomorphic : bool, optional
        Indicates whether residual function is promised to be holomorphic.
        (Default value = False)

    Returns
    -------
    xi : pytree or array-like
        Unknowns that minimize res as found via least-squares. Type will be the same
        as zXi specified in the input.

    time : float
        Computation time as calculated by timerType specified. This output is only
        returned if timer = True.
    """
    if J is None:
        if isinstance(zXi, TFCDictRobust):
            # Values may be multi-dimensional: flatten each key's Jacobian to 2-D
            # before stacking the keys side by side.
            def J(xi, *args):
                jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                return np.hstack(
                    [
                        jacob[k].reshape(jacob[k].shape[0], onp.prod(onp.array(xi[k].shape)))
                        for k in xi.keys()
                    ]
                )

        elif isinstance(zXi, TFCDict):

            def J(xi, *args):
                jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                return np.hstack([jacob[k] for k in xi.keys()])

        else:
            J = lambda xi, *args: jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)

    if method == "pinv":
        ls = lambda xi, *args: np.dot(np.linalg.pinv(J(xi, *args)), -res(xi, *args))
    elif method == "lstsq":
        ls = lambda xi, *args: np.linalg.lstsq(J(xi, *args), -res(xi, *args), rcond=None)[0]
    else:
        TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

    run_args: list[Any] = list(args)
    if constant_arg_nums:
        # Make arguments constant if desired. Work on a sorted copy so the caller's
        # list is never modified and pe always receives ascending indices.
        const_nums = sorted(constant_arg_nums)
        ls = pe(zXi, *args, constant_arg_nums=const_nums)(ls)

        # Drop the constant arguments from the call arguments. The indices count zXi
        # as argument 0, hence the k - 1. Removing from the back keeps the remaining
        # indices valid.
        for k in reversed(const_nums):
            run_args.pop(k - 1)

    ls = jit(ls)
    zXi = zerosRobust(zXi)

    if timer:
        import time

        timer_f: Callable[[], float] = getattr(time, timerType)
        # Warm-up call so that tracing/compilation is not included in the timing.
        ls(zXi, *run_args).block_until_ready()

        start = timer_f()
        xi = ls(zXi, *run_args).block_until_ready()
        stop = timer_f()
        zXi += xi
        return zXi, stop - start

    zXi += ls(zXi, *run_args)
    return zXi
class LsClass:
    """
    JITed linear least-squares class.
    Like the :func:`LS <tfc.utils.TFCUtils.LS>` function, but it is in class form so that the
    run method can be called multiple times without re-JITing. See
    :func:`LS <tfc.utils.TFCUtils.LS>` for more details.
    """

    def __init__(
        self,
        zXi: PyTree,
        res: Callable,
        *args: Any,
        constant_arg_nums: list[int] = [],
        J: Optional[Callable[..., np.ndarray]] = None,
        method: Literal["pinv", "lstsq"] = "pinv",
        timer: bool = False,
        timerType: str = "process_time",
        holomorphic: bool = False,
    ) -> None:
        """
        Initialization function. Creates the JIT-ed least-squares function.

        Parameters
        ----------
        zXi : PyTree
            Unknown parameters to be found using least-squares.

        res : Callable
            Residual function (also known as the loss function) with signature
            res(xi: PyTree, *args:Any, **kwargs:Any). Note, the first argument does not
            need to be named xi, this is just illustrative.

        *args : Any
            Any additional arguments taken by res other than the first PyTree argument.

        J : Optional[Callable[...,np.ndarray]]
            User specified Jacobian function. If None, then the Jacobian of res with
            respect to xi will be calculated via automatic differentiation.
            (Default value = None)

        constant_arg_nums: list[int], optional
            These arguments will be removed from the residual function and treated as
            constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details.
            The list is not modified. The arguments later given to run must not
            include the constant ones.

        method : Literal["pinv","lstsq"], optional
            Method for least-squares inversion. (Default value = "pinv")
            * pinv - Use np.linalg.pinv
            * lstsq - Use np.linalg.lstsq

        timer : bool, optional
            Boolean that chooses whether to time the code or not. (Default value = False).
            Note that setting to true adds a slight increase in runtime, as the
            least-squares function is run once before the first timing to avoid timing
            the JAX trace.

        timerType : str, optional
            Any timer from the time module. (Default value = "process_time")

        holomorphic : bool, optional
            Indicates whether residual function is promised to be holomorphic.
            (Default value = False)
        """
        self.timerType = timerType
        self.timer = timer
        self.holomorphic = holomorphic

        if J is None:
            if isinstance(zXi, TFCDictRobust):
                # Values may be multi-dimensional: flatten each key's Jacobian to
                # 2-D before stacking the keys side by side.
                def J(xi, *args):
                    jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)
                    return np.hstack(
                        [
                            jacob[k].reshape(
                                jacob[k].shape[0], onp.prod(onp.array(xi[k].shape))
                            )
                            for k in xi.keys()
                        ]
                    )

            elif isinstance(zXi, TFCDict):

                def J(xi, *args):
                    jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)
                    return np.hstack([jacob[k] for k in xi.keys()])

            else:
                J = lambda xi, *args: jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)

        if method == "pinv":
            ls = lambda xi, *args: np.dot(np.linalg.pinv(J(xi, *args)), -res(xi, *args))
        elif method == "lstsq":
            ls = lambda xi, *args: np.linalg.lstsq(J(xi, *args), -res(xi, *args), rcond=None)[0]
        else:
            TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

        if constant_arg_nums:
            # Make arguments constant if desired. A sorted copy is handed to pe so the
            # caller's list is never modified and pe always receives ascending indices.
            # The reduced argument list is not stored here: run receives its own
            # arguments on every call.
            ls = pe(zXi, *args, constant_arg_nums=sorted(constant_arg_nums))(ls)

        self._ls = jit(ls)
        # Tracks whether self._ls has been called at least once, so the timed branch
        # of run can skip the warm-up call.
        self._compiled = False

    def run(self, zXi: PyTree, *args: Any) -> PyTree | tuple[PyTree, float]:
        """
        Runs the JIT-ed least-squares function and times it if desired.

        Parameters
        ----------
        zXi : PyTree
            Unknown parameters to be found using least-squares. The least-squares
            update is added to this value.

        *args : Any
            Any additional arguments taken by res other than xi.

        Returns
        -------
        xi : PyTree
            Unknowns that minimize res as found via least-squares. Type will be the
            same as zXi specified in the input.

        time : float, optional
            Computation time as calculated by timerType specified. This output is only
            returned if timer = True.
        """
        if self.timer:
            import time

            timer = getattr(time, self.timerType)

            if not self._compiled:
                # Warm-up call so that tracing/compilation is not included in the timing.
                self._ls(zXi, *args).block_until_ready()
                self._compiled = True

            start = timer()
            xi = self._ls(zXi, *args).block_until_ready()
            stop = timer()
            zXi += xi

            return zXi, stop - start

        # The solver lives on the instance; there is no module-level ls function.
        zXi += self._ls(zXi, *args)
        self._compiled = True
        return zXi
def nlls_id_print(it: int, x, end: str = "\n"):
    """
    Print one progress line for a non-linear least-squares iteration.

    Parameters
    ----------
    it : int
        Iteration number to report.

    x : Any
        Value reported as max(abs(res)).

    end : str, optional
        String appended after the line, forwarded to print. (Default value = "\\n")
    """
    message = "Iteration: {0}\tmax(abs(res)): {1}".format(it, x)
    print(message, end=end)
# Typing-only overloads for NLLS: they let a static type checker pick the
# return type from the literal value passed for ``timer``. They carry no
# runtime behavior; the implementation is the un-decorated NLLS that follows.


# timer=False (or omitted): returns (xi, it).
@overload
def NLLS(
    xiInit: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    cond: Optional[Callable[[PyTree], bool]] = None,
    body: Optional[Callable[[PyTree], PyTree]] = None,
    tol: float = 1e-13,
    maxIter: uint = 50,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: Literal[False] = False,
    printOut: bool = False,
    printOutEnd: str = "\n",
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> tuple[PyTree, int]: ...


# timer=True: returns (xi, it, time).
# NOTE(review): ``timer`` has a default of True in this overload although the
# implementation defaults to False; a call that omits ``timer`` therefore
# matches both overloads. Consider making ``timer`` required here.
@overload
def NLLS(
    xiInit: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    cond: Optional[Callable[[PyTree], bool]] = None,
    body: Optional[Callable[[PyTree], PyTree]] = None,
    tol: float = 1e-13,
    maxIter: uint = 50,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: Literal[True] = True,
    printOut: bool = False,
    printOutEnd: str = "\n",
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> tuple[PyTree, int, float]: ...
def NLLS(
    xiInit: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    cond: Optional[Callable[[PyTree], bool]] = None,
    body: Optional[Callable[[PyTree], PyTree]] = None,
    tol: float = 1e-13,
    maxIter: uint = 50,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: bool = False,
    printOut: bool = False,
    printOutEnd: str = "\n",
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> tuple[PyTree, int] | tuple[PyTree, int, float]:
    """
    JIT-ed non-linear least squares.

    This function takes in an initial guess, xiInit (initial values of xi), and
    a residual function, res, and performs a nonlinear least squares to
    minimize the res function using the parameters xi.
    The default loop keeps iterating while all of the following hold:

    1. max(abs(res)) > tol
    2. max(abs(dxi)) > tol, where dxi is the change in xi from the last iteration.
    3. Number of iterations < maxIter.

    Parameters
    ----------
    xiInit : pytree or array-like
        Initial guess for the unknown parameters.

    res : function
        Residual function (also known as the loss function) with signature
        res(xi,*args).

    *args : iterable
        Any additional arguments taken by res other than xi.

    constant_arg_nums: list[int], optional
        These arguments will be removed from the residual function and treated
        as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more
        details. The list passed in is not modified.

    J : function
        User specified Jacobian. If None, then the Jacobian of res with respect
        to xi will be calculated via automatic differentiation.
        (Default value = None)

    cond : Optional[Callable[[PyTree],bool]]
        User specified condition function. If None, then the default cond
        function is used which checks the three termination criteria provided
        in the function description. (Default value = None)

    body : Optional[Callable[[PyTree],PyTree]]
        User specified while-loop body function. If None, then use the default
        body function which updates xi using a NLLS iteration and the method
        provided. (Default value = None)

    tol : float
        Tolerance used in the default termination criteria: see function
        description for more details. (Default value = 1e-13)

    maxIter : int, optional
        Maximum number of iterations. (Default value = 50)

    method : Literal["pinv","lstsq"], optional
        Method for least-squares inversion. (Default value = "pinv")

        * pinv - Use np.linalg.pinv
        * lstsq - Use np.linalg.lstsq

    timer : bool, optional
        Boolean that chooses whether to time the code or not.
        (Default value = False). Note that setting to true adds a slight
        increase in runtime, as one iteration of the non-linear least squares
        is run first to avoid timing the JAX trace.

    printOut : bool, optional
        Controls whether the NLLS prints out information each iteration or not.
        The printout consists of the iteration and max(abs(res)) at each
        iteration. (Default value = False)

    printOutEnd : str, optional
        Value of keyword argument end passed to the print statement used in
        printOut. (Default value = "\\n")

    timerType : str, optional
        Any timer from the time module. (Default value = "process_time")

    holomorphic : bool, optional
        Indicates whether residual function is promised to be holomorphic.
        (Default value = False)

    Returns
    -------
    xi : PyTree
        Unknowns that minimize res as found via least-squares. Type will be the
        same as xiInit specified in the input.

    it : int
        Number of NLLS iterations performed.

    time : float
        Computation time as calculated by timerType specified. This output is
        only returned if timer = True.
    """
    if timer and printOut:
        TFCPrint.Warning(
            "Warning, you have both the timer and printer on in the nonlinear least-squares.\nThe time will be longer than optimal due to the printout."
        )

    dictFlag = isinstance(xiInit, (TFCDict, TFCDictRobust))

    # Only install the default stopping test when the caller did not supply one;
    # defining it unconditionally would silently discard a user-provided cond.
    if cond is None:

        def cond(val):
            return np.all(
                np.array(
                    [
                        np.max(np.abs(res(val["xi"], *val["args"]))) > tol,
                        val["it"] < maxIter,
                        np.max(np.abs(val["dxi"])) > tol,
                    ]
                )
            )

    if J is None:
        if isinstance(xiInit, TFCDictRobust):

            def J(xi, *args):
                jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                # Flatten each entry's trailing dimensions so the per-key
                # Jacobians can be stacked side by side.
                return np.hstack(
                    [
                        jacob[k].reshape(jacob[k].shape[0], onp.prod(onp.array(xi[k].shape)))
                        for k in xi.keys()
                    ]
                )

        elif dictFlag:

            def J(xi, *args):
                jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                return np.hstack([jacob[k] for k in xi.keys()])

        else:

            def J(xi, *args):
                return jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)

    # LS returns the step dxi that is subtracted from xi in the loop body.
    if method == "pinv":

        def LS(xi, *args):
            return np.dot(np.linalg.pinv(J(xi, *args)), res(xi, *args))

    elif method == "lstsq":

        def LS(xi, *args):
            return np.linalg.lstsq(J(xi, *args), res(xi, *args), rcond=None)[0]

    else:
        TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

    loopArgs = args
    if constant_arg_nums:
        # Make arguments constant if desired.
        # NOTE(review): LS, J, cond and body all look up `res` by name at call
        # time, so the order of the next two statements is kept exactly as it
        # was; confirm against `pe` before reordering them.
        LS = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(LS)
        res = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(res)

        # Drop the now-constant arguments, highest index first so the earlier
        # positions stay valid. Argument number k maps to args[k - 1] because
        # xi occupies position 0. A sorted copy is used so the caller's list is
        # left untouched.
        remaining: list[Any] = list(args)
        for k in sorted(constant_arg_nums, reverse=True):
            remaining.pop(k - 1)
        loopArgs = remaining

    if body is None:
        if printOut:

            def body(val):
                val["dxi"] = LS(val["xi"], *val["args"])
                val["xi"] -= val["dxi"]
                io_callback(
                    partial(nlls_id_print, end=printOutEnd),
                    None,
                    val["it"],
                    np.max(np.abs(res(val["xi"], *val["args"]))),
                )
                val["it"] += 1
                return val

        else:

            def body(val):
                val["dxi"] = LS(val["xi"], *val["args"])
                val["xi"] -= val["dxi"]
                val["it"] += 1
                return val

    nlls = jit(lambda val: lax.while_loop(cond, body, val))

    # Start with dxi = 1 everywhere so the loop state has a dxi entry of the
    # right shape before the first update.
    if dictFlag:
        dxi = np.ones_like(cast(TFCDict | TFCDictRobust, xiInit).toArray())
    else:
        dxi = np.ones_like(xiInit)

    if not timer:
        val = {"xi": xiInit, "dxi": dxi, "it": 0, "args": loopArgs}
        val = nlls(val)
        return val["xi"], val["it"]

    import time

    timer_f: Callable[[], float] = getattr(time, timerType)

    # Warm-up: start at the last allowed iteration so at most one step is
    # executed while JAX traces and compiles the loop.
    val = {"xi": xiInit, "dxi": dxi, "it": maxIter - 1, "args": loopArgs}
    nlls(val)["dxi"].block_until_ready()

    val = {"xi": xiInit, "dxi": dxi, "it": 0, "args": loopArgs}
    start = timer_f()
    val = nlls(val)
    val["dxi"].block_until_ready()
    stop = timer_f()

    return val["xi"], val["it"], stop - start
class NllsClass:
    """
    JITed nonlinear least squares class.

    Like the :func:`NLLS <tfc.utils.TFCUtils.NLLS>` function, but it is in class form so that the
    run method can be called multiple times without re-JITing
    """

    def __init__(
        self,
        xiInit: PyTree,
        res: Callable,
        *args: Any,
        constant_arg_nums: list[int] = [],
        J: Optional[Callable[..., np.ndarray]] = None,
        cond: Optional[Callable[[PyTree], bool]] = None,
        body: Optional[Callable[[PyTree], PyTree]] = None,
        tol: float = 1e-13,
        maxIter: uint = 50,
        method: Literal["pinv", "lstsq"] = "pinv",
        timer: bool = False,
        printOut: bool = False,
        printOutEnd: str = "\n",
        timerType: str = "process_time",
        holomorphic: bool = False,
    ) -> None:
        """
        Initialization function. Creates the JIT-ed nonlinear least-squares function.

        The default loop keeps iterating while max(abs(res)) > tol,
        max(abs(dxi)) > tol and the iteration count is below maxIter.

        Parameters
        ----------
        xiInit : pytree or array-like
            Initial guess for the unknown parameters.

        res : function
            Residual function (also known as the loss function) with signature
            res(xi,*args).

        *args : iterable
            Any additional arguments taken by res other than xi.

        constant_arg_nums: list[int], optional
            These arguments will be removed from the residual function and
            treated as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>`
            for more details. The list passed in is not modified.

        J : function
            User specified Jacobian. If None, then the Jacobian of res with
            respect to xi will be calculated via automatic differentiation.
            (Default value = None)

        cond : Optional[Callable[[PyTree],bool]]
            User specified condition function. If None, then the default cond
            function is used which checks the three termination criteria
            listed above. (Default value = None)

        body : Optional[Callable[[PyTree],PyTree]]
            User specified while-loop body function. If None, then use the
            default body function which updates xi using a NLLS iteration and
            the method provided. (Default value = None)

        tol : float
            Tolerance used in the default termination criteria.
            (Default value = 1e-13)

        maxIter : int, optional
            Maximum number of iterations. (Default value = 50)

        method : Literal["pinv","lstsq"], optional
            Method for least-squares inversion. (Default value = "pinv")

            * pinv - Use np.linalg.pinv
            * lstsq - Use np.linalg.lstsq

        timer : bool, optional
            Boolean that chooses whether to time the code or not.
            (Default value = False). Note that setting to true adds a slight
            increase in runtime, as one iteration of the non-linear least
            squares is run first to avoid timing the JAX trace.

        printOut : bool, optional
            Controls whether the NLLS prints out information each iteration or
            not. The printout consists of the iteration and max(abs(res)) at
            each iteration. (Default value = False)

        printOutEnd : str, optional
            Value of keyword argument end passed to the print statement used in
            printOut. (Default value = "\\n")

        timerType : str, optional
            Any timer from the time module. (Default value = "process_time")

        holomorphic : bool, optional
            Indicates whether residual function is promised to be holomorphic.
            (Default value = False)
        """
        self.timerType = timerType
        self.timer = timer
        self._maxIter = maxIter
        self.holomorphic = holomorphic

        if timer and printOut:
            TFCPrint.Warning(
                "Warning, you have both the timer and printer on in the nonlinear least-squares.\nThe time will be longer than optimal due to the printout."
            )

        self._dictFlag = isinstance(xiInit, (TFCDict, TFCDictRobust))

        # Only install the default stopping test when the caller did not supply
        # one; defining it unconditionally would discard a user-provided cond.
        if cond is None:

            def cond(val):
                return np.all(
                    np.array(
                        [
                            np.max(np.abs(res(val["xi"], *val["args"]))) > tol,
                            val["it"] < maxIter,
                            np.max(np.abs(val["dxi"])) > tol,
                        ]
                    )
                )

        if J is None:
            if isinstance(xiInit, TFCDictRobust):

                def J(xi, *args):
                    jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)
                    # Flatten each entry's trailing dimensions so the per-key
                    # Jacobians can be stacked side by side.
                    return np.hstack(
                        [
                            jacob[k].reshape(
                                jacob[k].shape[0], onp.prod(onp.array(xi[k].shape))
                            )
                            for k in xi.keys()
                        ]
                    )

            elif self._dictFlag:

                def J(xi, *args):
                    jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)
                    return np.hstack([jacob[k] for k in xi.keys()])

            else:

                def J(xi, *args):
                    return jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)

        # LS returns the step dxi that is subtracted from xi in the loop body.
        if method == "pinv":

            def LS(xi, *args):
                return np.dot(np.linalg.pinv(J(xi, *args)), res(xi, *args))

        elif method == "lstsq":

            def LS(xi, *args):
                return np.linalg.lstsq(J(xi, *args), res(xi, *args), rcond=None)[0]

        else:
            TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

        if constant_arg_nums:
            # Make arguments constant if desired.
            # NOTE(review): LS, J, cond and body all look up `res` by name at
            # call time, so the order of the next two statements is kept
            # exactly as it was; confirm against `pe` before reordering them.
            LS = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(LS)
            res = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(res)

            # Mirror the removal of the constant arguments on a local copy,
            # highest index first. A sorted copy of constant_arg_nums is used
            # so the caller's list is left untouched. The reduced list is not
            # stored: run() receives its own *args.
            remaining: list[Any] = list(args)
            for k in sorted(constant_arg_nums, reverse=True):
                remaining.pop(k - 1)

        if body is None:
            if printOut:

                def body(val):
                    val["dxi"] = LS(val["xi"], *val["args"])
                    val["xi"] -= val["dxi"]
                    io_callback(
                        partial(nlls_id_print, end=printOutEnd),
                        None,
                        val["it"],
                        np.max(np.abs(res(val["xi"], *val["args"]))),
                    )
                    val["it"] += 1
                    return val

            else:

                def body(val):
                    val["dxi"] = LS(val["xi"], *val["args"])
                    val["xi"] -= val["dxi"]
                    val["it"] += 1
                    return val

        self._nlls = jit(lambda val: lax.while_loop(cond, body, val))
        self._compiled = False

    def run(self, xiInit: PyTree, *args: Any) -> tuple[PyTree, int] | tuple[PyTree, int, float]:
        """Runs the JIT-ed nonlinear least-squares function and times it if desired.

        Parameters
        ----------
        xiInit : PyTree
            Initial guess for the unknown parameters.

        *args : Any
            Any additional arguments taken by res other than xi.

        Returns
        -------
        xi : PyTree
            Unknowns that minimize res as found via least-squares. Type will be
            the same as xiInit specified in the input.

        it : int
            Number of NLLS iterations performed.

        time : float
            Computation time as calculated by timerType specified. This output
            is only returned if timer = True.
        """
        # Start with dxi = 1 everywhere so the loop state has a dxi entry of
        # the right shape before the first update.
        if self._dictFlag:
            dxi = np.ones_like(cast(TFCDict | TFCDictRobust, xiInit).toArray())
        else:
            dxi = np.ones_like(xiInit)

        if not self.timer:
            val = {"xi": xiInit, "dxi": dxi, "it": 0, "args": args}
            val = self._nlls(val)
            self._compiled = True
            return val["xi"], val["it"]

        import time

        timer_f: Callable[[], float] = getattr(time, self.timerType)

        if not self._compiled:
            # Warm-up: start at the last allowed iteration so at most one step
            # is executed while JAX traces and compiles the loop.
            val = {"xi": xiInit, "dxi": dxi, "it": self._maxIter - 1, "args": args}
            self._nlls(val)["dxi"].block_until_ready()
            self._compiled = True

        val = {"xi": xiInit, "dxi": dxi, "it": 0, "args": args}
        start = timer_f()
        val = self._nlls(val)
        val["dxi"].block_until_ready()
        stop = timer_f()

        return val["xi"], val["it"], stop - start
class ComponentConstraintDict(TypedDict):
    """
    Typed dictionary describing one component constraint, i.e., one edge passed
    to :class:`ComponentConstraintGraph`.
    """

    # Name of the component constraint.
    name: str
    # Name of one of the two nodes that make up the component constraint.
    node0: str
    # Name of the other node that makes up the component constraint.
    node1: str
class ComponentConstraintGraph:
    """
    Creates a graph of all valid ways in which component constraints can be embedded.
    """

    def __init__(self, N: list[str], E: list[ComponentConstraintDict]) -> None:
        """
        Class constructor.

        Enumerates every source/target assignment of the edges and records the
        indices of those whose directed graph contains no cycle.

        Parameters
        ----------
        N : list[str]
            A list of strings that specify the node names. These node names
            typically coincide with the names of the dependent variables.

        E : list[ComponentConstraintDict]
            The ComponentConstraintDict is a dictionary with the following fields:

            * name - Name of the component constraint.
            * node0 - The name of one of the nodes that makes up the component
              constraint. Must correspond with an element of the list given in N.
            * node1 - The name of one of the nodes that makes up the component
              constraint. Must correspond with an element of the list given in N.
        """
        import itertools

        self.nNodes = len(N)
        self.nEdges = len(E)

        # Check that all edges are connected to valid nodes
        for edge in E:
            if not (edge["node0"] in N and edge["node1"] in N):
                TFCPrint.Error(
                    "Error either "
                    + edge["node0"]
                    + " or "
                    + edge["node1"]
                    + " is not a valid node. Make sure they appear in the nodes list."
                )

        # Create all possible source/target pairs. This tells whether node0 is
        # the target or source, node1 will be the opposite.
        self.targets = list(itertools.product([0, 1], repeat=self.nEdges))

        # Map each node name to its first position in N (same as N.index, but
        # looked up once instead of once per edge per orientation).
        nodeIndex: dict[str, int] = {}
        for pos, nodeName in enumerate(N):
            nodeIndex.setdefault(nodeName, pos)
        edgeEnds = [(nodeIndex[edge["node0"]], nodeIndex[edge["node1"]]) for edge in E]

        # Find all targets whose directed graph is acyclic
        self.goodTargets = [
            j
            for j, orientation in enumerate(self.targets)
            if self._IsAcyclic(self._Adjacency(orientation, edgeEnds, self.nNodes))
        ]

        # Save nodes and edges for use later
        self.N = N
        self.E = E

    @staticmethod
    def _Adjacency(orientation, edgeEnds, nNodes: int) -> onp.ndarray:
        """Build the boolean adjacency matrix for one source/target assignment."""
        adj = onp.zeros((nNodes, nNodes), dtype=bool)
        for flag, (n0, n1) in zip(orientation, edgeEnds):
            if flag:
                adj[n1, n0] = True
            elif n0 != n1:
                # An edge joining a node to itself contributes an entry only
                # when its flag is set; this keeps the long-standing behavior.
                adj[n0, n1] = True
        return adj

    @staticmethod
    def _IsAcyclic(adj: onp.ndarray) -> bool:
        """
        Return True when the directed graph described by adj has no cycle.

        A graph on n nodes is acyclic exactly when the n-th power of its
        adjacency matrix is zero (equivalently, all eigenvalues are zero).
        Boolean matrix products make the test exact, with no floating-point
        eigenvalue comparison involved.
        """
        walks = adj
        for _ in range(adj.shape[0] - 1):
            if not walks.any():
                return True
            walks = walks @ adj
        return not walks.any()

    def SaveGraphs(self, outputDir: Path, allGraphs: bool = False, savePDFs: bool = False) -> None:
        """
        Saves the graphs.

        The graphs are saved in a clickable HTML structure. They can also be
        saved as PDFs.

        Parameters
        ----------
        outputDir : Path
            Output directory to save in.

        allGraphs : bool, optional
            Boolean that controls whether all graphs are saved or just valid
            graphs. (Default value = False)

        savePDFs : bool, optional
            Boolean that controls whether the graphs are also saved as PDFs.
            (Default value = False)
        """
        import os
        from .Html import HTML, Dot

        if allGraphs:
            targets = self.targets
        else:
            targets = [self.targets[k] for k in self.goodTargets]

        n = len(targets)
        perRow = 5  # number of tree boxes placed on one row of the overview

        # Create the main dot file
        mainDot = Dot(os.path.join(outputDir, "dotFiles", "main"), "main")
        mainDot.dot.node_attr.update(shape="box")
        mainDot.dot.edge_attr.update(style="invis")
        for row, first in enumerate(range(0, n, perRow)):
            if row != 0:
                # Invisible edge between the first boxes of consecutive rows
                mainDot.dot.edge("tree" + str(first - perRow), "tree" + str(first))
            with mainDot.dot.subgraph(name="subgraph" + str(row)) as c:
                c.attr(rank="same")
                for treeCnt in range(first, min(first + perRow, n)):
                    c.node(
                        "tree" + str(treeCnt),
                        "Tree " + str(treeCnt),
                        href=os.path.join("htmlFiles", "tree" + str(treeCnt) + ".html"),
                    )
        mainDot.Render()

        # Create the main file HTML
        mainHtml = HTML(os.path.join(outputDir, "main.html"))
        with mainHtml.tag("html"):
            with mainHtml.tag("body"):
                with mainHtml.tag("style"):
                    mainHtml.doc.asis(mainHtml.centerClass)
                mainHtml.doc.stag(
                    "img", src=os.path.join("dotFiles", "main.svg"), usemap="#main", klass="center"
                )
                mainHtml.doc.asis(
                    mainHtml.ReadFile(os.path.join(outputDir, "dotFiles", "main.cmapx"))
                )
        mainHtml.WriteFile()

        # Create the tree dot files
        for k, orientation in enumerate(targets):
            treeName = "tree" + str(k)
            treeDot = Dot(os.path.join(outputDir, "dotFiles", treeName), treeName)
            treeDot.dot.attr(bgcolor="transparent")
            treeDot.dot.node_attr.update(shape="box")
            for nodeName in self.N:
                treeDot.dot.node(nodeName, nodeName)
            for flag, edge in zip(orientation, self.E):
                if not flag:
                    treeDot.dot.edge(edge["node0"], edge["node1"], label=edge["name"])
                else:
                    treeDot.dot.edge(edge["node1"], edge["node0"], label=edge["name"])
            if savePDFs:
                treeDot.Render(formats=["cmapx", "svg", "pdf"])
            else:
                treeDot.Render()

        # Create the tree HTML files
        for k in range(n):
            treeName = "tree" + str(k)
            treeHtml = HTML(os.path.join(outputDir, "htmlFiles", treeName + ".html"))
            with treeHtml.tag("html"):
                with treeHtml.tag("body"):
                    with treeHtml.tag("style"):
                        treeHtml.doc.asis(treeHtml.centerClass)
                    treeHtml.doc.stag(
                        "img",
                        src=os.path.join("..", "dotFiles", treeName + ".svg"),
                        usemap="#" + treeName,
                        klass="center",
                    )
                    treeHtml.doc.asis(
                        treeHtml.ReadFile(os.path.join(outputDir, "dotFiles", treeName + ".cmapx"))
                    )
            treeHtml.WriteFile()
def ScaledQrLs(A: np.ndarray, B: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """This function performs least-squares using a scaled QR method.

    Each column of A is divided by its norm, the scaled matrix is factored with
    QR, the scaled problem is solved with the pseudo-inverse of R, and the
    scaling is then undone on the solution.

    Parameters
    ----------
    A : np.ndarray
        A matrix in A*x = B.

    B : np.ndarray
        B matrix in A*x = B. May be a vector or a matrix with one right-hand
        side per column.

    Returns
    -------
    x : np.ndarray
        Solution to A*x = B solved using a scaled QR method.

    cn : np.ndarray
        Condition number of the R factor of the scaled matrix.
    """
    colNorms = np.sqrt(np.sum(A * A, 0))
    # Leave all-zero columns unscaled instead of dividing by zero, which would
    # fill the whole solution with NaNs.
    S = 1.0 / np.where(colNorms != 0, colNorms, 1.0)
    S = np.reshape(S, (A.shape[1],))

    # Broadcasting scales column j by S[j]; this equals A @ diag(S) without
    # building the dense diagonal matrix.
    q, r = np.linalg.qr(A * S)

    y = np.linalg.multi_dot([_MatPinv(r), q.T, B])
    # Undoing the scaling multiplies row j of the solution by S[j]. For a
    # matrix right-hand side S must therefore broadcast down the rows, not
    # across the columns.
    scale = S if y.ndim == 1 else S[:, None]
    x: np.ndarray = scale * y

    cn: np.ndarray = cast(np.ndarray, np.linalg.cond(r))
    return x, cn
def _MatPinv(A: np.ndarray) -> np.ndarray: """This function is used to better replicate MATLAB's pseudo-inverse. Parameters ---------- A : np.ndarray Matrix to be inverted. Returns ------- Ainv : np.ndarray Inverse of A. """ rcond = onp.max(A.shape) * onp.spacing(np.linalg.norm(A, ord=2)) return np.linalg.pinv(A, rcond=rcond)
def step(x: np.ndarray) -> np.ndarray:
    """
    Unit step function evaluated element-wise with ``np.heaviside``; the value
    returned where x equals zero is 0.

    Parameters
    ----------
    x : np.ndarray
        Array to apply step to.

    Returns
    -------
    step_x : np.ndarray
        step(x)
    """
    valueAtZero = 0
    step_x = np.heaviside(x, valueAtZero)
    return step_x
def criuCheckpoint(dir: str = "criu_checkpoint", user_mode: bool = False):
    """
    Use CRIU to create a checkpoint of your program.

    WARNING: You must ensure no external sockets are used, i.e., via matplotlib.
    WARNING: GPU memory cannot be mapped for all GPUs.
    WARNING: user_mode is a work in progress and is not full featured yet.

    Parameters
    ----------
    dir : str
        Directory where the CRIU checkpoint will be created. It is created
        (including missing parents) if it does not exist.

    user_mode : bool
        Whether to use user mode or not. (Default value = False)
    """
    import os
    import shlex
    from pathlib import Path

    path = Path(dir)
    pid = os.getpid()

    # Create path if it does not exist
    path.mkdir(parents=True, exist_ok=True)

    # The directory ends up inside a shell command line, so quote it; a path
    # with spaces or shell metacharacters would otherwise break the command.
    target = shlex.quote(str(path.absolute()))

    if user_mode:
        dumpCmd = "criu dump --unprivileged"
    else:
        dumpCmd = "sudo criu dump"

    # Run the command and print how to restart. The "( ... &) &" wrapper
    # launches criu in the background from a subshell so this call returns
    # immediately while the dump of this process (pid) proceeds.
    os.system(f"({dumpCmd} -t {pid} -D {target} --shell-job --leave-running &) &")
    print(f"Creating a checkpoint. To restart run: sudo criu restore -D {target} --shell-job")