Program Listing for File TFCUtils.py
↰ Return to documentation for file (src/tfc/utils/TFCUtils.py)
import sys
from colorama import init as initColorama
from colorama import Fore as fg
from colorama import Style as style
from collections import OrderedDict
from functools import partial
from jax._src.config import config
config.update("jax_enable_x64", True)
import numpy as onp
import numpy.typing as npt
import jax.numpy as np
from jax import jvp, jit, lax, jacfwd, typeof
from jax.extend import linear_util as lu
from jax.api_util import debug_info
from jax.tree_util import register_pytree_node, tree_map
from jax._src.api_util import flatten_fun
from jax._src.tree_util import tree_flatten
from jax.core import eval_jaxpr
from jax.interpreters.partial_eval import trace_to_jaxpr_nounits, PartialVal
from jax.experimental import io_callback
from typing import Any, Callable, Optional, cast, Union, overload
from .tfc_types import uint, Literal, TypedDict, Path
from jaxtyping import PyTree
# Operand types accepted by the arithmetic operators (+, -, +=, -=) of TFCDict:
# an array, a plain dictionary, or another TFCDict.
TFCDictAddable = Union[np.ndarray, dict[Any, Any], "TFCDict"]
# Operand types accepted by the arithmetic operators (+, -, +=, -=) of TFCDictRobust:
# an array, a plain dictionary, or another TFCDictRobust.
TFCDictRobustAddable = Union[np.ndarray, dict[Any, Any], "TFCDictRobust"]
class TFCPrint:
    """
    Helper class for writing colored messages to the terminal.
    """

    def __init__(self):
        """
        Constructor. Calls the colorama initialization routine.
        """
        initColorama()

    @staticmethod
    def _emit(color: str, stringIn: str) -> None:
        """
        Print 'stringIn' in the bright variant of 'color' and then reset the
        terminal style.

        Parameters
        ----------
        color : str
            colorama foreground color code
        stringIn : str
            text to print
        """
        highlighted = color + style.BRIGHT + stringIn
        print(highlighted)
        # The reset code is printed without a trailing newline so that no extra
        # blank line appears after the message.
        print(style.RESET_ALL, end="")

    @staticmethod
    def Error(stringIn: str):
        """
        Print the error message 'stringIn' in bright red and exit the program.

        Parameters
        ----------
        stringIn : str
            error string
        """
        TFCPrint._emit(fg.RED, stringIn)
        sys.exit()

    @staticmethod
    def Warning(stringIn: str):
        """
        Print the warning message 'stringIn' in bright yellow.

        Parameters
        ----------
        stringIn : str
            warning string
        """
        TFCPrint._emit(fg.YELLOW, stringIn)
def egrad(g: Callable[..., Any], j: uint = 0):
    """
    This function mimics egrad from autograd.

    Parameters
    ----------
    g : Callable[..., Any]
        Function to take the derivative of.
    j : uint, optional
        Parameter with which to take the derivative with respect to. (Default value = 0)

    Returns
    -------
    wrapped : function
        Derivative function
    """

    def wrapped(*args: Any) -> Any:
        """
        Wrapper for derivative of g with respect to parameter number j.

        Parameters
        ----------
        *args : Any
            function arguments to g; each one must expose `shape` and `dtype`.

        Returns
        -------
        x_bar: Any
            derivative of g with respect to parameter number j
        """
        # Seed the forward-mode pass with ones for argument j and zeros for every
        # other argument. The tangents are created with the dtype of the matching
        # argument because jvp requires primal and tangent dtypes to agree; a
        # default float64 tangent is rejected for e.g. float32 or complex inputs.
        tans = tuple(
            onp.ones(arg.shape, dtype=arg.dtype)
            if i == j
            else onp.zeros(arg.shape, dtype=arg.dtype)
            for i, arg in enumerate(args)
        )
        _, x_bar = jvp(g, args, tans)
        return x_bar

    return wrapped
@partial(partial, tree_map)
def onesRobust(val: PyTree):
    """
    Build a pytree of ones that mirrors val; works for arrays and dictionaries.

    Parameters
    ----------
    val : PyTree
        Pytree whose leaves expose `shape` and `dtype`.

    Returns
    -------
    ones_like_val : PyTree
        Pytree with the same structure as val with all elements equal to one.
    """
    # The decorator maps this function over the leaves of the pytree via tree_map,
    # so inside the body val is a single array-like leaf. The cast only exists to
    # keep static analysis tools quiet.
    leaf = cast(npt.NDArray, val)
    shape, dtype = leaf.shape, leaf.dtype
    return onp.ones(shape, dtype=dtype)
@partial(partial, tree_map)
def zerosRobust(val: PyTree):
    """
    Build a pytree of zeros that mirrors val; works for arrays and dictionaries.

    Parameters
    ----------
    val : PyTree
        Pytree whose leaves expose `shape` and `dtype`.

    Returns
    -------
    zeros_like_val : PyTree
        Pytree with the same structure as val with all elements equal to zero.
    """
    # The decorator maps this function over the leaves of the pytree via tree_map,
    # so inside the body val is a single array-like leaf. The cast only exists to
    # keep static analysis tools quiet.
    leaf = cast(npt.NDArray, val)
    shape, dtype = leaf.shape, leaf.dtype
    return onp.zeros(shape, dtype=dtype)
def egradRobust(g: Callable[..., Any], j: uint = 0):
    """This function mimics egrad from autograd, but can also handle dictionaries.

    Parameters
    ----------
    g : function
        Function to take the derivative of.
    j : integer, optional
        Parameter with which to take the derivative with respect to. (Default value = 0)

    Returns
    -------
    wrapped : function
        Derivative function
    """
    # Differentiate the underlying function when g carries this qualified name
    # (presumably the name JAX gives to jit-wrapped functions -- confirm against
    # the JAX version in use). getattr is used because not every callable has a
    # __qualname__ (e.g. functools.partial objects and callable instances).
    if getattr(g, "__qualname__", None) == "jit.<locals>.f_jitted":
        g = g.__wrapped__

    def wrapped(*args: Any) -> Any:
        """
        Wrapper for derivative of g with respect to parameter number j.

        Parameters
        ----------
        *args : iterable
            function arguments to g

        Returns
        -------
        x_bar: array-like
            derivative of g with respect to parameter number j
        """
        # Tangent seed: a pytree of ones for argument j, pytrees of zeros elsewhere.
        tans = tuple(
            onesRobust(arg) if i == j else zerosRobust(arg) for i, arg in enumerate(args)
        )
        _, x_bar = jvp(g, args, tans)
        return x_bar

    return wrapped
def pe(*args: Any, constant_arg_nums: list[int] = []) -> Any:
"""
Decorator that returns a function evaluated such that the arg numbers specified in constant_arg_nums
and all functions that utilizes only those arguments are treated as compile time constants.
Parameters
----------
*args : Any
Arguments for the function that pe is applied to.
constant_arg_nums : list[int], optional
The arguments whose values and functions that depend only on these values should be
treated as cached constants.
Returns
-------
f : Any
The new function whose constant_arg_num arguments have been removed. The jaxpr of this
function has the constant_arg_num values and all functions that depend on those values
cached as constants.
Usage
-----
@pe(*args, constant_arg_nums=[0])
def f(x,xi):
# Function stuff here
# Returns an f(xi) with x treated as constant
"""
# Reorder to put knowns first, then unknowns
order = [k for k in range(len(args))]
for k in constant_arg_nums:
order.insert(0, order.pop(k))
reorder = np.argsort(np.array(order))
dark = tuple(args[k] for k in order)
# Store the removed args for later
num_args_remove = len(constant_arg_nums)
def wrapper(f_orig):
if len(constant_arg_nums) > 0:
# Reordering args so the ones to remove are given first
# This will allow us to return a function that has completely removed those args
# Moreover, we do it here so this reordering will be optimized by the compiler
def f(*args):
new_args = tuple(args[k] for k in reorder)
return f_orig(*new_args)
# Create the partial args needed by trace_to_jaxpr_nounits
def get_arg(a, unknown):
if unknown:
return tree_flatten(
(
tree_map(lambda x: PartialVal.unknown(typeof(x).at_least_vspace()), a),
{},
)
)[0]
else:
return PartialVal.known(a)
part_args = []
for k, a in enumerate(dark):
temp = get_arg(a, k >= num_args_remove)
if isinstance(temp, list):
part_args += temp
else:
part_args.append(temp)
part_args = tuple(part_args)
# Create jaxpr
wrap = lu.wrap_init(f, debug_info=debug_info(f.__name__, f, [], {}))
_, in_tree = tree_flatten((dark, {}))
wrap_flat, out_tree = flatten_fun(wrap, in_tree)
jaxpr, _, const = trace_to_jaxpr_nounits(wrap_flat, part_args)
# Create new, partially evaluated function
if out_tree().num_leaves == 1 and out_tree().num_nodes == 1:
# out_tree() is PyTreeDef(*), so just return the value. Since eval_jaxpr returns a list,
# this is just value [0]
f_removed = lambda *args: eval_jaxpr(jaxpr, const, *tree_flatten((*args, {}))[0])[0]
else:
# Use out_tree() to reshape the args correctly.
f_removed = lambda *args: out_tree().unflatten(
eval_jaxpr(jaxpr, const, *tree_flatten((*args, {}))[0])
)
return f_removed
else:
return f_orig
return wrapper
def pejit(*args: Any, constant_arg_nums: list[int] = [], **kwargs) -> Any:
    """
    Works like :func:`pe <tfc.utils.TFCUtils.pe>`, but also JITs the returned function. See :func:`pe <tfc.utils.TFCUtils.pe>` for more details.

    Parameters
    ----------
    *args: Any
        Arguments for the function that pe is applied to.
    constant_arg_nums: list[int], optional
        The arguments whose values and functions that depend only on these values should be
        treated as cached constants.
    **kwargs: Any
        Keyword arguments passed on to JIT.

    Returns
    -------
    f: Any
        The new function whose constant_arg_num arguments have been removed. The jaxpr of this
        function has the constant_arg_num values and all functions that depend on those values
        cached, and they are treated as compile time constants.
    """

    def decorator(f_orig):
        # First partially evaluate the function, then JIT the result.
        partially_evaluated = pe(*args, constant_arg_nums=constant_arg_nums)(f_orig)
        return jit(partially_evaluated, **kwargs)

    return decorator
class TFCDict(OrderedDict):
    """
    This is the TFC dictionary class. It extends an OrderedDict and
    adds a few methods that enable:

    - Adding dictionaries with the same keys together
    - Turning a dictionary into a 1-D array
    - Turning a 1-D array into a dictionary
    """

    def __init__(self, *args: Any):
        """
        Initialize TFCDict using the OrderedDict method.

        Parameters
        ----------
        *args : Any
            Arguments to pass to the OrderedDict
        """
        # Store dictionary and keep a record of the keys. Keys will stay in same
        # order, so that adding and subtracting is repeatable.
        super().__init__(*args)
        self._keys = list(self.keys())
        self._nKeys = len(self._keys)
        self.getSlices()

    def getSlices(self) -> None:
        """
        Function that creates slices for each of the keys in the dictionary.

        When every value is a JAX or NumPy array, slice k selects the elements of
        the flat array (see toArray) that belong to key k. Otherwise every slice
        is set to None.
        """
        if all(isinstance(value, (np.ndarray, onp.ndarray)) for value in self.values()):
            self._slices = []
            stop = 0
            for key in self._keys:
                start = stop
                stop = start + self[key].shape[0]
                self._slices.append(slice(start, stop, 1))
        else:
            self._slices = [None] * self._nKeys

    def update(self, *args: Any) -> None:
        """
        Overload the update method to update the _keys variable as well.

        Parameters
        ----------
        *args : Any
            Same as *args for the update method of ordered dict.
        """
        super().update(*args)
        self._keys = list(self.keys())
        self._nKeys = len(self._keys)
        self.getSlices()

    def toArray(self) -> np.ndarray:
        """
        Send dictionary to a flat JAX array.

        Returns
        -------
        np.ndarray
            This dictionary as a flat JAX array.
        """
        return cast(np.ndarray, np.hstack(list(self.values())))

    def toDict(self, arr: np.ndarray) -> "TFCDict":
        """
        Send a flat JAX array to a TFCDict with the same keys.

        Parameters
        ----------
        arr : ndarray
            Flat JAX array to convert to TFCDict. Must have the same number of elements as total number of elements in the dictionary.

        Returns
        -------
        TFCDict
            JAX array as a TFCDict
        """
        flat = arr.flatten()
        return TFCDict(zip(self._keys, [flat[sl] for sl in self._slices]))

    def block_until_ready(self) -> "TFCDict":
        """
        Mimics block_until_ready for jax arrays. Used to halt the program until the computation that created the
        dictionary is finished.

        Returns
        -------
        TFCDict
            This TFCDict.
        """
        # Wait on every value, not just the first one. Values without a
        # block_until_ready method (e.g. NumPy arrays) are skipped, and an empty
        # dictionary is returned as is.
        for value in self.values():
            block = getattr(value, "block_until_ready", None)
            if block is not None:
                block()
        return self

    def _terms(self, o: TFCDictAddable) -> list[tuple[Any, Any]]:
        """
        Pair every key of this dictionary with the matching part of the operand o.

        Parameters
        ----------
        o : TFCDictAddable
            Dictionary (looked up by key) or JAX/NumPy array (flattened and split
            using the stored slices).

        Returns
        -------
        list[tuple[Any, Any]]
            (key, value) pairs in key order. Empty if o is neither a dictionary
            nor an array, so such operands leave the result unchanged.
        """
        if isinstance(o, dict):
            return [(key, o[key]) for key in self._keys]
        if isinstance(o, (np.ndarray, onp.ndarray)):
            flat = o.flatten()
            return [(key, flat[sl]) for key, sl in zip(self._keys, self._slices)]
        return []

    def __iadd__(self, o: TFCDictAddable) -> "TFCDict":
        """
        Used to overload "+=" for TFCDict so that 2 TFCDict's can be added together.

        Parameters
        ----------
        o : TFCDictAddable
            Values to add to the current dictionary.

        Returns
        -------
        self : TFCDict
            This dictionary (not a copy) after adding in the values from o.
        """
        for key, term in self._terms(o):
            self[key] += term
        return self

    def __isub__(self, o: TFCDictAddable) -> "TFCDict":
        """
        Used to overload "-=" for TFCDict so that 2 TFCDict's can be subtracted.

        Parameters
        ----------
        o : TFCDictAddable
            Values to subtract from the current dictionary.

        Returns
        -------
        self : TFCDict
            This dictionary (not a copy) after subtracting the values from o.
        """
        for key, term in self._terms(o):
            self[key] -= term
        return self

    def __add__(self, o: TFCDictAddable) -> "TFCDict":
        """
        Used to overload "+" for TFCDict so that 2 TFCDict's can be added together.

        Parameters
        ----------
        o : TFCDictAddable
            Values to add to the current dictionary.

        Returns
        -------
        out : TFCDict
            A TFCDict with values = self + o.
        """
        out = TFCDict(self)
        for key, term in self._terms(o):
            # out is a shallow copy, so the values are rebound instead of updated
            # with "+=": an in-place update of a NumPy array would also change
            # the array stored in self.
            out[key] = out[key] + term
        return out

    def __sub__(self, o: TFCDictAddable) -> "TFCDict":
        """
        Used to overload "-" for TFCDict so that 2 TFCDict's can be subtracted.

        Parameters
        ----------
        o : TFCDictAddable
            Values to subtract from the current dictionary.

        Returns
        -------
        out : TFCDict
            A TFCDict with values = self - o.
        """
        out = TFCDict(self)
        for key, term in self._terms(o):
            # See __add__: rebinding keeps the arrays stored in self untouched.
            out[key] = out[key] - term
        return out
# Register TFCDict as a JAX pytree node: the values are the children and the
# keys are carried along as auxiliary data used to rebuild the dictionary.
register_pytree_node(
    TFCDict,
    lambda tfc_dict: ([*tfc_dict.values()], [*tfc_dict.keys()]),
    lambda aux_keys, children: TFCDict(zip(aux_keys, children)),
)
class TFCDictRobust(OrderedDict):
    """This class is like the :class:`TFCDict <tfc.utils.TFCUtils.TFCDict>` class, but it handles non-flat arrays."""

    def __init__(self, *args: Any):
        """
        Initialize TFCDictRobust using the OrderedDict method.

        Parameters
        ----------
        *args : Any
            Arguments to pass to the OrderedDict
        """
        # Store dictionary and keep a record of the keys. Keys will stay in same
        # order, so that adding and subtracting is repeatable.
        super().__init__(*args)
        self._keys = list(self.keys())
        self._nKeys = len(self._keys)
        self.getSlices()

    def getSlices(self) -> None:
        """
        Function that creates slices for each of the keys in the dictionary.

        When every value is a JAX or NumPy array, slice k selects the elements of
        the flat array (see toArray) that belong to the flattened value of key k.
        Otherwise every slice is set to None.
        """
        if all(isinstance(value, (np.ndarray, onp.ndarray)) for value in self.values()):
            self._slices = []
            stop = 0
            for key in self._keys:
                start = stop
                # Number of elements of the flattened value.
                stop = start + self[key].size
                self._slices.append(slice(start, stop, 1))
        else:
            self._slices = [None] * self._nKeys

    def update(self, *args: Any) -> None:
        """
        Overload the update method to update the _keys variable as well.

        Parameters
        ----------
        *args : Any
            Same as *args for the update method on ordered dict.
        """
        super().update(*args)
        self._keys = list(self.keys())
        self._nKeys = len(self._keys)
        self.getSlices()

    def toArray(self) -> np.ndarray:
        """
        Send dictionary to a flat JAX array.

        Returns
        -------
        np.ndarray
            This dictionary as a flat JAX array.
        """
        return cast(np.ndarray, np.hstack([value.flatten() for value in self.values()]))

    def toDict(self, arr: np.ndarray) -> "TFCDictRobust":
        """
        Send a flat JAX array to a TFCDictRobust with the same keys.

        Parameters
        ----------
        arr : np.ndarray
            Flat JAX array to convert to TFCDictRobust. Must have the same number of elements as total number of elements in the dictionary.

        Returns
        -------
        TFCDictRobust
            JAX array as a TFCDictRobust
        """
        flat = arr.flatten()
        return TFCDictRobust(
            zip(
                self._keys,
                [
                    flat[sl].reshape(self[key].shape)
                    for key, sl in zip(self._keys, self._slices)
                ],
            )
        )

    def block_until_ready(self) -> "TFCDictRobust":
        """
        Mimics block_until_ready for jax arrays. Used to halt the program until the computation that created the
        dictionary is finished.

        Returns
        -------
        TFCDictRobust
            This TFCDictRobust
        """
        # Wait on every value, not just the first one. Values without a
        # block_until_ready method (e.g. NumPy arrays) are skipped, and an empty
        # dictionary is returned as is.
        for value in self.values():
            block = getattr(value, "block_until_ready", None)
            if block is not None:
                block()
        return self

    def _terms(self, o: TFCDictRobustAddable) -> list[tuple[Any, Any]]:
        """
        Pair every key of this dictionary with the matching part of the operand o.

        Parameters
        ----------
        o : TFCDictRobustAddable
            Dictionary (looked up by key) or JAX/NumPy array (flattened, split
            using the stored slices and reshaped to the shape of each value).

        Returns
        -------
        list[tuple[Any, Any]]
            (key, value) pairs in key order. Empty if o is neither a dictionary
            nor an array, so such operands leave the result unchanged.
        """
        if isinstance(o, dict):
            return [(key, o[key]) for key in self._keys]
        if isinstance(o, (np.ndarray, onp.ndarray)):
            flat = o.flatten()
            return [
                (key, flat[sl].reshape(self[key].shape))
                for key, sl in zip(self._keys, self._slices)
            ]
        return []

    def __iadd__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
        """
        Used to overload "+=" for TFCDictRobust so that 2 TFCDictRobust's can be added together.

        Parameters
        ----------
        o : TFCDictRobustAddable
            Values to add to the current dictionary.

        Returns
        -------
        self : TFCDictRobust
            This dictionary (not a copy) after adding in the values from o.
        """
        for key, term in self._terms(o):
            self[key] += term
        return self

    def __isub__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
        """
        Used to overload "-=" for TFCDictRobust so that 2 TFCDictRobust's can be subtracted.

        Parameters
        ----------
        o : TFCDictRobustAddable
            Values to subtract from the current dictionary.

        Returns
        -------
        self : TFCDictRobust
            This dictionary (not a copy) after subtracting the values from o.
        """
        for key, term in self._terms(o):
            self[key] -= term
        return self

    def __add__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
        """
        Used to overload "+" for TFCDictRobust so that 2 TFCDictRobust's can be added together.

        Parameters
        ----------
        o : TFCDictRobustAddable
            Values to add to the current dictionary.

        Returns
        -------
        out : TFCDictRobust
            A TFCDictRobust with values = self + o.
        """
        out = TFCDictRobust(self)
        for key, term in self._terms(o):
            # out is a shallow copy, so the values are rebound instead of updated
            # with "+=": an in-place update of a NumPy array would also change
            # the array stored in self.
            out[key] = out[key] + term
        return out

    def __sub__(self, o: TFCDictRobustAddable) -> "TFCDictRobust":
        """
        Used to overload "-" for TFCDictRobust so that 2 TFCDictRobust's can be subtracted.

        Parameters
        ----------
        o : TFCDictRobustAddable
            Values to subtract from the current dictionary.

        Returns
        -------
        out : TFCDictRobust
            A TFCDictRobust with values = self - o.
        """
        out = TFCDictRobust(self)
        for key, term in self._terms(o):
            # See __add__: rebinding keeps the arrays stored in self untouched.
            out[key] = out[key] - term
        return out
# Register TFCDictRobust as a JAX pytree node: the values are the children and
# the keys are carried along as auxiliary data used to rebuild the dictionary.
register_pytree_node(
    TFCDictRobust,
    lambda tfc_dict: ([*tfc_dict.values()], [*tfc_dict.keys()]),
    lambda aux_keys, children: TFCDictRobust(zip(aux_keys, children)),
)
@overload
def LS(
    zXi: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: Literal[False] = False,
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> PyTree: ...
# timer has no default in this overload so that the two overloads do not overlap:
# calls that omit timer resolve to the overload above.
@overload
def LS(
    zXi: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: Literal[True],
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> tuple[PyTree, float]: ...
def LS(
    zXi: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: bool = False,
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> PyTree | tuple[PyTree, float]:
    """
    JITed least squares.
    This function takes in an initial guess of zeros, zXi, and a residual function, res, and
    performs a linear least squares to minimize the res function using the parameters
    xi.

    Parameters
    ----------
    zXi : PyTree
        Unknown parameters to be found using least-squares.
    res : Callable
        Residual function (also known as the loss function) with signature res(xi: PyTree, *args:Any, **kwargs:Any).
        Note, the first argument does not need to be named xi, this is just illustrative.
    *args : Any
        Any additional arguments taken by res other than the first PyTree argument.
    constant_arg_nums: list[int], optional
        These arguments will be removed from the residual function and treated as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details.
        The numbers refer to positions in (zXi, *args). The list is not modified.
    J : Optional[Callable[...,np.ndarray]]
        User specified Jacobian function. If None, then the Jacobian of res with respect to xi will be calculated via automatic differentiation. (Default value = None)
    method : Literal["pinv","lstsq"], optional
        Method for least-squares inversion. (Default value = "pinv")
        * pinv - Use np.linalg.pinv
        * lstsq - Use np.linalg.lstsq
    timer : bool, optional
        Boolean that chooses whether to time the code or not. (Default value = False). Note that setting to true adds a slight increase in runtime,
        as the least-squares function is run once before the timed run to avoid timing the JAX trace.
    timerType : str, optional
        Any timer from the time module. (Default value = "process_time")
    holomorphic : bool, optional
        Indicates whether residual function is promised to be holomorphic. (Default value = False)

    Returns
    -------
    xi : pytree or array-like
        Unknowns that minimize res as found via least-squares. Type will be the same as zXi specified in the input.
    time : float
        Computation time as calculated by timerType specified. This output is only returned if timer = True.
    """
    if J is None:
        if isinstance(zXi, TFCDictRobust):

            def J(xi, *args):
                # Flatten the Jacobian block of every key to 2-D and stack the
                # blocks side by side.
                jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                return np.hstack(
                    [
                        jacob[k].reshape(jacob[k].shape[0], onp.prod(onp.array(xi[k].shape)))
                        for k in xi.keys()
                    ]
                )

        elif isinstance(zXi, TFCDict):

            def J(xi, *args):
                jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                return np.hstack([jacob[k] for k in xi.keys()])

        else:
            J = lambda xi, *args: jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)

    if method == "pinv":
        ls = lambda xi, *args: np.dot(np.linalg.pinv(J(xi, *args)), -res(xi, *args))
    elif method == "lstsq":
        ls = lambda xi, *args: np.linalg.lstsq(J(xi, *args), -res(xi, *args), rcond=None)[0]
    else:
        TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

    run_args: Any = args
    if constant_arg_nums:
        # Make arguments constant if desired. A sorted copy is used so that the
        # list passed in by the caller is left untouched.
        const_nums = sorted(constant_arg_nums)
        ls = pe(zXi, *args, constant_arg_nums=const_nums)(ls)
        run_args = list(args)
        # Remove from the back so the remaining positions stay valid. The numbers
        # index (zXi, *args), hence the offset of one.
        for k in reversed(const_nums):
            run_args.pop(k - 1)

    ls = jit(ls)
    zXi = zerosRobust(zXi)

    if timer:
        import time

        timer_f: Callable[[], float] = getattr(time, timerType)
        # Untimed first call, so that tracing is not part of the timing.
        ls(zXi, *run_args).block_until_ready()
        start = timer_f()
        xi = ls(zXi, *run_args).block_until_ready()
        stop = timer_f()
        zXi += xi
        return zXi, stop - start
    else:
        zXi += ls(zXi, *run_args)
        return zXi
class LsClass:
    """
    JITed linear least-squares class.
    Like the :func:`LS <tfc.utils.TFCUtils.LS>` function, but it is in class form so that the run method can be called multiple times without re-JITing.
    See :func:`LS <tfc.utils.TFCUtils.LS>` for more details.
    """

    def __init__(
        self,
        zXi: PyTree,
        res: Callable,
        *args: Any,
        constant_arg_nums: list[int] = [],
        J: Optional[Callable[..., np.ndarray]] = None,
        method: Literal["pinv", "lstsq"] = "pinv",
        timer: bool = False,
        timerType: str = "process_time",
        holomorphic: bool = False,
    ) -> None:
        """
        Initialization function. Creates the JIT-ed least-squares function.

        Parameters
        ----------
        zXi : PyTree
            Unknown parameters to be found using least-squares.
        res : Callable
            Residual function (also known as the loss function) with signature res(xi: PyTree, *args:Any, **kwargs:Any).
            Note, the first argument does not need to be named xi, this is just illustrative.
        *args : Any
            Any additional arguments taken by res other than the first PyTree argument.
        J : Optional[Callable[...,np.ndarray]]
            User specified Jacobian function. If None, then the Jacobian of res with respect to xi will be calculated via automatic differentiation. (Default value = None)
        constant_arg_nums: list[int], optional
            These arguments will be removed from the residual function and treated as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details.
            The list is not modified.
        method : Literal["pinv","lstsq"], optional
            Method for least-squares inversion. (Default value = "pinv")
            * pinv - Use np.linalg.pinv
            * lstsq - Use np.linalg.lstsq
        timer : bool, optional
            Boolean that chooses whether to time the code or not. (Default value = False). Note that setting to true adds a slight increase in runtime,
            as the least-squares function is run once before the first timed run to avoid timing the JAX trace.
        timerType : str, optional
            Any timer from the time module. (Default value = "process_time")
        holomorphic : bool, optional
            Indicates whether residual function is promised to be holomorphic. (Default value = False)
        """
        self.timerType = timerType
        self.timer = timer
        self.holomorphic = holomorphic

        if J is None:
            if isinstance(zXi, TFCDictRobust):

                def J(xi, *args):
                    # Flatten the Jacobian block of every key to 2-D and stack
                    # the blocks side by side.
                    jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)
                    return np.hstack(
                        [
                            jacob[k].reshape(
                                jacob[k].shape[0], onp.prod(onp.array(xi[k].shape))
                            )
                            for k in xi.keys()
                        ]
                    )

            elif isinstance(zXi, TFCDict):

                def J(xi, *args):
                    jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)
                    return np.hstack([jacob[k] for k in xi.keys()])

            else:
                J = lambda xi, *args: jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *args)

        if method == "pinv":
            ls = lambda xi, *args: np.dot(np.linalg.pinv(J(xi, *args)), -res(xi, *args))
        elif method == "lstsq":
            ls = lambda xi, *args: np.linalg.lstsq(J(xi, *args), -res(xi, *args), rcond=None)[0]
        else:
            TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

        if constant_arg_nums:
            # Make arguments constant if desired. A sorted copy is passed on so
            # that the list given by the caller is left untouched. The remaining
            # (non-constant) arguments are supplied by the caller through run.
            ls = pe(zXi, *args, constant_arg_nums=sorted(constant_arg_nums))(ls)

        self._ls = jit(ls)
        self._compiled = False

    def run(self, zXi: PyTree, *args: Any) -> PyTree | tuple[PyTree, float]:
        """
        Runs the JIT-ed least-squares function and times it if desired.

        Parameters
        ----------
        zXi : PyTree
            Unknown parameters to be found using least-squares. The result is
            added to zXi with "+=".
        *args : Any
            Any additional arguments taken by res other than xi.

        Returns
        -------
        xi : PyTree
            Unknowns that minimize res as found via least-squares. Type will be the same as zXi specified in the input.
        time : float, optional
            Computation time as calculated by timerType specified. This output is only returned if timer = True.
        """
        if self.timer:
            import time

            timer = getattr(time, self.timerType)
            if not self._compiled:
                # Untimed first call, so that tracing is not part of the timing.
                self._ls(zXi, *args).block_until_ready()
                self._compiled = True
            start = timer()
            xi = self._ls(zXi, *args).block_until_ready()
            stop = timer()
            zXi += xi
            return zXi, stop - start
        else:
            # The JIT-ed function is stored on the instance; there is no name
            # "ls" in this scope.
            zXi += self._ls(zXi, *args)
            self._compiled = True
            return zXi
def nlls_id_print(it: int, x, end: str = "\n"):
    """
    Print the iteration number and the maximum absolute residual.

    Parameters
    ----------
    it : int
        Iteration number.
    x : Any
        Value reported as max(abs(res)).
    end : str, optional
        Value of keyword argument end passed to print. (Default value = newline)
    """
    message = "Iteration: {0}\tmax(abs(res)): {1}".format(it, x)
    print(message, end=end)
@overload
def NLLS(
    xiInit: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    cond: Optional[Callable[[PyTree], bool]] = None,
    body: Optional[Callable[[PyTree], PyTree]] = None,
    tol: float = 1e-13,
    maxIter: uint = 50,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: Literal[False] = False,
    printOut: bool = False,
    printOutEnd: str = "\n",
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> tuple[PyTree, int]: ...
# timer has no default in this overload so that the two overloads do not overlap:
# calls that omit timer resolve to the overload above.
@overload
def NLLS(
    xiInit: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    cond: Optional[Callable[[PyTree], bool]] = None,
    body: Optional[Callable[[PyTree], PyTree]] = None,
    tol: float = 1e-13,
    maxIter: uint = 50,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: Literal[True],
    printOut: bool = False,
    printOutEnd: str = "\n",
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> tuple[PyTree, int, float]: ...
def NLLS(
    xiInit: PyTree,
    res: Callable,
    *args: Any,
    constant_arg_nums: list[int] = [],
    J: Optional[Callable[..., np.ndarray]] = None,
    cond: Optional[Callable[[PyTree], bool]] = None,
    body: Optional[Callable[[PyTree], PyTree]] = None,
    tol: float = 1e-13,
    maxIter: uint = 50,
    method: Literal["pinv", "lstsq"] = "pinv",
    timer: bool = False,
    printOut: bool = False,
    printOutEnd: str = "\n",
    timerType: str = "process_time",
    holomorphic: bool = False,
) -> tuple[PyTree, int] | tuple[PyTree, int, float]:
    """
    JIT-ed non-linear least squares.
    This function takes in an initial guess, xiInit (initial values of xi), and a residual function, res, and
    performs a nonlinear least squares to minimize the res function using the parameters
    xi. The conditions on terminating the nonlinear least-squares are:

    1. max(abs(res)) < tol
    2. max(abs(dxi)) < tol, where dxi is the change in xi from the last iteration.
    3. Number of iterations > maxIter.

    Parameters
    ----------
    xiInit : pytree or array-like
        Initial guess for the unknown parameters.
    res : function
        Residual function (also known as the loss function) with signature res(xi,*args).
    *args : iterable
        Any additional arguments taken by res other than xi.
    constant_arg_nums: list[int], optional
        These arguments will be removed from the residual function and treated as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details.
        The numbers refer to positions in (xiInit, *args). The list is not modified.
    J : function
        User specified Jacobian. If None, then the Jacobian of res with respect to xi will be calculated via automatic differentiation. (Default value = None)
    cond : Optional[Callable[[PyTree],bool]]
        User specified condition function. If None, then the default cond function is used which checks the three termination criteria
        provided in the function description. (Default value = None)
    body : Optional[Callable[[PyTree],PyTree]]
        User specified while-loop body function. If None, then use the default body function which updates xi using a NLLS iteration and the method provided.
        (Default value = None)
    tol : float
        Tolerance used in the default termination criteria: see function description for more details. (Default value = 1e-13)
    maxIter : int, optional
        Maximum number of iterations. (Default value = 50)
    method : Literal["pinv","lstsq"], optional
        Method for least-squares inversion. (Default value = "pinv")
        * pinv - Use np.linalg.pinv
        * lstsq - Use np.linalg.lstsq
    timer : bool, optional
        Boolean that chooses whether to time the code or not. (Default value = False). Note that setting to true adds a slight increase in runtime,
        as one iteration of the non-linear least squares is run first to avoid timing the JAX trace.
    printOut : bool, optional
        Controls whether the NLLS prints out information each iteration or not. The printout consists of the iteration and max(abs(res)) at each iteration. (Default value = False)
    printOutEnd : str, optional
        Value of keyword argument end passed to the print statement used in printOut. (Default value = "\\\\n")
    timerType : str, optional
        Any timer from the time module. (Default value = "process_time")
    holomorphic : bool, optional
        Indicates whether residual function is promised to be holomorphic. (Default value = False)

    Returns
    -------
    xi : PyTree
        Unknowns that minimize res as found via least-squares. Type will be the same as xiInit specified in the input.
    it : int
        Number of NLLS iterations performed.
    time : float
        Computation time as calculated by timerType specified. This output is only returned if timer = True.
    """
    if timer and printOut:
        TFCPrint.Warning(
            "Warning, you have both the timer and printer on in the nonlinear least-squares.\nThe time will be longer than optimal due to the printout."
        )

    dictFlag = isinstance(xiInit, (TFCDict, TFCDictRobust))

    # Only fall back to the default termination check when the caller did not
    # supply a condition function.
    if cond is None:

        def cond(val):
            # Keep iterating while the residual and the last step are both above
            # the tolerance and the iteration limit has not been reached.
            return np.all(
                np.array(
                    [
                        np.max(np.abs(res(val["xi"], *val["args"]))) > tol,
                        val["it"] < maxIter,
                        np.max(np.abs(val["dxi"])) > tol,
                    ]
                )
            )

    if J is None:
        if isinstance(xiInit, TFCDictRobust):

            def J(xi, *args):
                # Flatten the Jacobian block of every key to 2-D and stack the
                # blocks side by side.
                jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                return np.hstack(
                    [
                        jacob[k].reshape(jacob[k].shape[0], onp.prod(onp.array(xi[k].shape)))
                        for k in xi.keys()
                    ]
                )

        elif isinstance(xiInit, TFCDict):

            def J(xi, *args):
                jacob = jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)
                return np.hstack([jacob[k] for k in xi.keys()])

        else:
            J = lambda xi, *args: jacfwd(res, 0, holomorphic=holomorphic)(xi, *args)

    if method == "pinv":
        ls_step = lambda xi, *args: np.dot(np.linalg.pinv(J(xi, *args)), res(xi, *args))
    elif method == "lstsq":
        ls_step = lambda xi, *args: np.linalg.lstsq(J(xi, *args), res(xi, *args), rcond=None)[0]
    else:
        TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

    loop_args: Any = args
    if constant_arg_nums:
        # Make arguments constant if desired. A sorted copy is used so that the
        # list passed in by the caller is left untouched.
        const_nums = sorted(constant_arg_nums)
        # ls_step must be partially evaluated before res is rebound: the step
        # function reads res through its closure and is traced with the full
        # argument list.
        ls_step = pe(xiInit, *args, constant_arg_nums=const_nums)(ls_step)
        res = pe(xiInit, *args, constant_arg_nums=const_nums)(res)
        loop_args = list(args)
        # Remove from the back so the remaining positions stay valid. The numbers
        # index (xiInit, *args), hence the offset of one.
        for k in reversed(const_nums):
            loop_args.pop(k - 1)

    if body is None:
        if printOut:

            def body(val):
                val["dxi"] = ls_step(val["xi"], *val["args"])
                val["xi"] -= val["dxi"]
                io_callback(
                    partial(nlls_id_print, end=printOutEnd),
                    None,
                    val["it"],
                    np.max(np.abs(res(val["xi"], *val["args"]))),
                )
                val["it"] += 1
                return val

        else:

            def body(val):
                val["dxi"] = ls_step(val["xi"], *val["args"])
                val["xi"] -= val["dxi"]
                val["it"] += 1
                return val

    nlls = jit(lambda val: lax.while_loop(cond, body, val))

    # Initial step of ones so that the step-size criterion does not stop the
    # loop before the first iteration.
    if dictFlag:
        dxi = np.ones_like(cast(TFCDict | TFCDictRobust, xiInit).toArray())
    else:
        dxi = np.ones_like(xiInit)

    if timer:
        import time

        timer_f: Callable[[], float] = getattr(time, timerType)
        # Untimed run that starts at the last allowed iteration, so that tracing
        # is not part of the timing.
        val = {"xi": xiInit, "dxi": dxi, "it": maxIter - 1, "args": loop_args}
        nlls(val)["dxi"].block_until_ready()
        val = {"xi": xiInit, "dxi": dxi, "it": 0, "args": loop_args}
        start = timer_f()
        val = nlls(val)
        val["dxi"].block_until_ready()
        stop = timer_f()
        return val["xi"], val["it"], stop - start
    else:
        val = {"xi": xiInit, "dxi": dxi, "it": 0, "args": loop_args}
        val = nlls(val)
        return val["xi"], val["it"]
class NllsClass:
    """
    JITed nonlinear least squares class.

    Like the :func:`NLLS <tfc.utils.TFCUtils.NLLS>` function, but it is in class form so that the
    run method can be called multiple times without re-JITing.

    With the default condition function, the iteration continues while all of the following hold:

    * max(abs(res)) > tol
    * it < maxIter
    * max(abs(dxi)) > tol
    """

    def __init__(
        self,
        xiInit: PyTree,
        res: Callable,
        *args: Any,
        constant_arg_nums: list[int] = [],
        J: Optional[Callable[..., np.ndarray]] = None,
        cond: Optional[Callable[[PyTree], bool]] = None,
        body: Optional[Callable[[PyTree], PyTree]] = None,
        tol: float = 1e-13,
        maxIter: uint = 50,
        method: Literal["pinv", "lstsq"] = "pinv",
        timer: bool = False,
        printOut: bool = False,
        printOutEnd: str = "\n",
        timerType: str = "process_time",
        holomorphic: bool = False,
    ) -> None:
        """
        Initialization function. Creates the JIT-ed nonlinear least-squares function.

        Parameters
        ----------
        xiInit : pytree or array-like
            Initial guess for the unknown parameters.
        res : function
            Residual function (also known as the loss function) with signature res(xi,*args).
        *args : iterable
            Any additional arguments taken by res other than xi.
        constant_arg_nums: list[int], optional
            These arguments will be removed from the residual function and treated as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details.
            The list is only read; it is never modified.
        J : function
            User specified Jacobian. If None, then the Jacobian of res with respect to xi will be calculated via automatic differentiation. (Default value = None)
        cond : Optional[Callable[[PyTree],bool]]
            User specified condition function. It receives the loop state, a dictionary with keys "xi", "dxi", "it", and "args".
            If None, then the default cond function is used which checks the three termination criteria
            provided in the class description. (Default value = None)
        body : Optional[Callable[[PyTree],PyTree]]
            User specified while-loop body function. If None, then use the default body function which updates xi using a NLLS iteration and the method provided.
            (Default value = None)
        tol : float
            Tolerance used in the default termination criteria: see class description for more details. (Default value = 1e-13)
        maxIter : int, optional
            Maximum number of iterations. (Default value = 50)
        method : Literal["pinv","lstsq"], optional
            Method for least-squares inversion. (Default value = "pinv")
            * pinv - Use np.linalg.pinv
            * lstsq - Use np.linalg.lstsq
        timer : bool, optional
            Boolean that chooses whether to time the code or not. (Default value = False). Note that setting to true adds a slight increase in runtime,
            as one iteration of the non-linear least squares is run first to avoid timing the JAX trace.
        printOut : bool, optional
            Controls whether the NLLS prints out information each iteration or not. The printout consists of the iteration and max(abs(res)) at each iteration. (Default value = False)
        printOutEnd : str, optional
            Value of keyword argument end passed to the print statement used in printOut. (Default value = "\\\\n")
        timerType : str, optional
            Any timer from the time module. (Default value = "process_time")
        holomorphic : bool, optional
            Indicates whether residual function is promised to be holomorphic. (Default value = False)
        """
        self.timerType = timerType
        self.timer = timer
        self._maxIter = maxIter
        self.holomorphic = holomorphic

        if timer and printOut:
            TFCPrint.Warning(
                "Warning, you have both the timer and printer on in the nonlinear least-squares.\nThe time will be longer than optimal due to the printout."
            )

        self._dictFlag = isinstance(xiInit, (TFCDict, TFCDictRobust))

        # Only fall back to the default condition when the user did not supply one;
        # otherwise the documented `cond` argument would be silently discarded.
        if cond is None:

            def cond(val):
                return np.all(
                    np.array(
                        [
                            np.max(np.abs(res(val["xi"], *val["args"]))) > tol,
                            val["it"] < maxIter,
                            np.max(np.abs(val["dxi"])) > tol,
                        ]
                    )
                )

        if J is None:
            if isinstance(xiInit, TFCDictRobust):

                def J(xi, *resArgs):
                    # Flatten each entry's Jacobian to 2-D before stacking the columns
                    jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *resArgs)
                    return np.hstack(
                        [
                            jacob[k].reshape(
                                jacob[k].shape[0], onp.prod(onp.array(xi[k].shape))
                            )
                            for k in xi.keys()
                        ]
                    )

            elif self._dictFlag:

                def J(xi, *resArgs):
                    jacob = jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *resArgs)
                    return np.hstack([jacob[k] for k in xi.keys()])

            else:

                def J(xi, *resArgs):
                    return jacfwd(res, 0, holomorphic=self.holomorphic)(xi, *resArgs)

        if method == "pinv":

            def LS(xi, *resArgs):
                return np.dot(np.linalg.pinv(J(xi, *resArgs)), res(xi, *resArgs))

        elif method == "lstsq":

            def LS(xi, *resArgs):
                return np.linalg.lstsq(J(xi, *resArgs), res(xi, *resArgs), rcond=None)[0]

        else:
            TFCPrint.Error("The method entered is not valid. Please enter a valid method.")

        if constant_arg_nums:
            # Make arguments constant if desired.
            # LS must be wrapped before `res` is rebound: LS (and J) look `res` up through
            # the enclosing scope, and they need the residual that still accepts every argument.
            LS = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(LS)
            res = pe(xiInit, *args, constant_arg_nums=constant_arg_nums)(res)

        if body is None:

            def body(val):
                val["dxi"] = LS(val["xi"], *val["args"])
                val["xi"] -= val["dxi"]
                # printOut is a plain Python bool, so this branch is resolved at trace time
                if printOut:
                    io_callback(
                        partial(nlls_id_print, end=printOutEnd),
                        None,
                        val["it"],
                        np.max(np.abs(res(val["xi"], *val["args"]))),
                    )
                val["it"] += 1
                return val

        self._nlls = jit(lambda val: lax.while_loop(cond, body, val))
        self._compiled = False

    def run(self, xiInit: PyTree, *args: Any) -> tuple[PyTree, int] | tuple[PyTree, int, float]:
        """Runs the JIT-ed nonlinear least-squares function and times it if desired.

        Parameters
        ----------
        xiInit : PyTree
            Initial guess for the unknown parameters.
        *args : Any
            Any additional arguments taken by res other than xi. They are forwarded unchanged to the
            residual used inside the loop.
            NOTE(review): when constant_arg_nums was given to the constructor, these presumably must
            omit the constant arguments - confirm against pe.

        Returns
        -------
        xi : PyTree
            Unknowns that minimize res as found via least-squares. Type will be the same as xiInit specified in the input.
        it : int
            Number of NLLS iterations performed.
        time : float
            Computation time as calculated by timerType specified. This output is only returned if timer = True.
        """
        if self._dictFlag:
            dxi = np.ones_like(cast(TFCDict | TFCDictRobust, xiInit).toArray())
        else:
            dxi = np.ones_like(xiInit)

        def initialState(it):
            return {"xi": xiInit, "dxi": dxi, "it": it, "args": args}

        if not self.timer:
            val = self._nlls(initialState(0))
            self._compiled = True
            return val["xi"], val["it"]

        import time

        timer_f: Callable[[], float] = getattr(time, self.timerType)

        if not self._compiled:
            # Warm-up call starting at the last allowed iteration so the JAX trace/compile
            # cost is not included in the measured time.
            self._nlls(initialState(self._maxIter - 1))["dxi"].block_until_ready()
            self._compiled = True

        start = timer_f()
        val = self._nlls(initialState(0))
        val["dxi"].block_until_ready()
        stop = timer_f()
        return val["xi"], val["it"], stop - start
class ComponentConstraintDict(TypedDict):
    """
    Typed dictionary describing one component constraint, i.e., one edge of a
    :class:`ComponentConstraintGraph`.
    """

    # Name of the component constraint.
    name: str
    # Name of one of the nodes that makes up the component constraint.
    node0: str
    # Name of the other node that makes up the component constraint.
    node1: str
class ComponentConstraintGraph:
    """
    Creates a graph of all valid ways in which component constraints can be embedded.

    Every component constraint (edge) can be oriented in two ways. All 2**nEdges orientations are
    enumerated in ``targets``; the indices of the orientations whose directed graph contains no
    cycle are stored in ``goodTargets``.
    """

    def __init__(self, N: list[str], E: list[ComponentConstraintDict]) -> None:
        """
        Class constructor.

        Parameters
        ----------
        N : list[str]
            A list of strings that specify the node names. These node names typically coincide with
            the names of the dependent variables.
        E : list[ComponentConstraintDict]
            The ComponentConstraintDict is a dictionary with the following fields:
            * name - Name of the component constraint.
            * node0 - The name of one of the nodes that makes up the component constraint. Must correspond with an element of the list given in N.
            * node1 - The name of one of the nodes that makes up the component constraint. Must correspond with an element of the list given in N.
        """
        import itertools

        self.nNodes = len(N)
        self.nEdges = len(E)

        # Check that all edges are connected to valid nodes
        for edge in E:
            if not (edge["node0"] in N and edge["node1"] in N):
                TFCPrint.Error(
                    "Error either "
                    + edge["node0"]
                    + " or "
                    + edge["node1"]
                    + " is not a valid node. Make sure they appear in the nodes list."
                )

        # Create all possible source/target pairs. This tells whether node0 is the target or source, node1 will be the opposite.
        self.targets = list(itertools.product([0, 1], repeat=self.nEdges))

        # Resolve node names to indices once instead of once per orientation
        endpoints = [(N.index(edge["node0"]), N.index(edge["node1"])) for edge in E]

        # Keep the orientations whose directed graph is acyclic. This is decided exactly by a
        # topological ordering rather than by comparing computed eigenvalues with 0.0.
        self.goodTargets = [
            j
            for j, target in enumerate(self.targets)
            if self._isAcyclic(self._directedEdges(endpoints, target), self.nNodes)
        ]

        # Save nodes and edges for use later
        self.N = N
        self.E = E

    @staticmethod
    def _directedEdges(
        endpoints: list[tuple[int, int]], target: tuple[int, ...]
    ) -> set[tuple[int, int]]:
        """Return the set of directed (source, destination) index pairs for one orientation.

        A flag of 1 orients the edge node1 -> node0, a flag of 0 orients it node0 -> node1.
        A constraint whose node0 and node1 coincide contributes a self-loop only when its flag is 1.
        Duplicate directed pairs collapse into one entry.
        """
        directed = set()
        for (first, second), flag in zip(endpoints, target):
            if flag:
                directed.add((second, first))
            elif first != second:
                directed.add((first, second))
        return directed

    @staticmethod
    def _isAcyclic(edges: set[tuple[int, int]], nNodes: int) -> bool:
        """Return True when the directed graph on nNodes nodes with the given edges has no cycle.

        Nodes without incoming edges are removed repeatedly (Kahn's algorithm); the graph is
        acyclic exactly when every node can be removed.
        """
        successors: list[list[int]] = [[] for _ in range(nNodes)]
        inDegree = [0] * nNodes
        for src, dst in edges:
            successors[src].append(dst)
            inDegree[dst] += 1

        ready = [node for node in range(nNodes) if inDegree[node] == 0]
        nRemoved = 0
        while ready:
            node = ready.pop()
            nRemoved += 1
            for nxt in successors[node]:
                inDegree[nxt] -= 1
                if inDegree[nxt] == 0:
                    ready.append(nxt)
        return nRemoved == nNodes

    def SaveGraphs(self, outputDir: Path, allGraphs: bool = False, savePDFs: bool = False) -> None:
        """
        Saves the graphs.

        The graphs are saved in a clickable HTML structure. They can also be saved as PDFs.

        Parameters
        ----------
        outputDir : Path
            Output directory to save in.
        allGraphs : bool, optional
            Boolean that controls whether all graphs are saved or just valid graphs. (Default value = False)
        savePDFs : bool, optional
            Boolean that controls whether the graphs are also saved as PDFs. (Default value = False)
        """
        import os

        from .Html import HTML, Dot

        if allGraphs:
            targets = self.targets
        else:
            targets = [self.targets[k] for k in self.goodTargets]
        n = len(targets)
        dotDir = os.path.join(outputDir, "dotFiles")

        #: Create the main dot file
        mainDot = Dot(os.path.join(dotDir, "main"), "main")
        mainDot.dot.node_attr.update(shape="box")
        mainDot.dot.edge_attr.update(style="invis")
        # Trees are laid out in rows of five; invisible edges between the first tree of
        # consecutive rows keep the rows stacked in order.
        for rowStart in range(0, n, 5):
            if rowStart != 0:
                mainDot.dot.edge(f"tree{rowStart - 5}", f"tree{rowStart}")
            with mainDot.dot.subgraph(name=f"subgraph{rowStart // 5}") as c:
                c.attr(rank="same")
                for treeCnt in range(rowStart, min(rowStart + 5, n)):
                    c.node(
                        f"tree{treeCnt}",
                        f"Tree {treeCnt}",
                        href=os.path.join("htmlFiles", f"tree{treeCnt}.html"),
                    )
        mainDot.Render()

        #: Create the main file HTML
        mainHtml = HTML(os.path.join(outputDir, "main.html"))
        with mainHtml.tag("html"):
            with mainHtml.tag("body"):
                with mainHtml.tag("style"):
                    mainHtml.doc.asis(mainHtml.centerClass)
                mainHtml.doc.stag(
                    "img", src=os.path.join("dotFiles", "main.svg"), usemap="#main", klass="center"
                )
                mainHtml.doc.asis(mainHtml.ReadFile(os.path.join(dotDir, "main.cmapx")))
        mainHtml.WriteFile()

        #: Create the tree dot files
        renderKwargs = {"formats": ["cmapx", "svg", "pdf"]} if savePDFs else {}
        for k, target in enumerate(targets):
            treeName = f"tree{k}"
            treeDot = Dot(os.path.join(dotDir, treeName), treeName)
            treeDot.dot.attr(bgcolor="transparent")
            treeDot.dot.node_attr.update(shape="box")
            for nodeName in self.N:
                treeDot.dot.node(nodeName, nodeName)
            for edge, flag in zip(self.E, target):
                # A flag of 0 draws node0 -> node1, a flag of 1 draws node1 -> node0
                if flag:
                    tail, head = edge["node1"], edge["node0"]
                else:
                    tail, head = edge["node0"], edge["node1"]
                treeDot.dot.edge(tail, head, label=edge["name"])
            treeDot.Render(**renderKwargs)

        #: Create the tree HTML files
        for k in range(n):
            treeName = f"tree{k}"
            treeHtml = HTML(os.path.join(outputDir, "htmlFiles", treeName + ".html"))
            with treeHtml.tag("html"):
                with treeHtml.tag("body"):
                    with treeHtml.tag("style"):
                        treeHtml.doc.asis(treeHtml.centerClass)
                    treeHtml.doc.stag(
                        "img",
                        src=os.path.join("..", "dotFiles", treeName + ".svg"),
                        usemap="#" + treeName,
                        klass="center",
                    )
                    treeHtml.doc.asis(treeHtml.ReadFile(os.path.join(dotDir, treeName + ".cmapx")))
            treeHtml.WriteFile()
def ScaledQrLs(A: np.ndarray, B: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """This function performs least-squares using a scaled QR method.

    Each column of A is scaled to unit Euclidean norm, the scaled matrix is QR-factored, and the
    solution of the scaled problem is mapped back by applying the same scale factors to its rows.

    Parameters
    ----------
    A : np.ndarray
        A matrix in A*x = B.
    B : np.ndarray
        B matrix in A*x = B. May be a vector of length A.shape[0] or a matrix with A.shape[0] rows.

    Returns
    -------
    x : np.ndarray
        Solution to A*x = B solved using a scaled QR method.
    cn : np.ndarray
        Condition number of the R factor of the column-scaled A.
    """
    colNorm = np.reshape(np.sqrt(np.sum(A * A, 0)), (A.shape[1],))
    # An all-zero column has no norm to scale by; leave it unscaled instead of producing inf/NaN.
    colScale = 1.0 / np.where(colNorm > 0, colNorm, 1.0)

    # Broadcasting over the last axis scales the columns, same as A.dot(diag(colScale))
    q, r = np.linalg.qr(A * colScale)
    scaledSol = np.linalg.multi_dot([_MatPinv(r), q.T, B])

    # Undo the scaling row-wise. For a 2-D right-hand side the factors must be broadcast down
    # the rows; a plain product would apply them along the last axis instead.
    if scaledSol.ndim > 1:
        x: np.ndarray = colScale[:, None] * scaledSol
    else:
        x = colScale * scaledSol

    cn: np.ndarray = cast(np.ndarray, np.linalg.cond(r))
    return x, cn
def _MatPinv(A: np.ndarray) -> np.ndarray:
"""This function is used to better replicate MATLAB's pseudo-inverse.
Parameters
----------
A : np.ndarray
Matrix to be inverted.
Returns
-------
Ainv : np.ndarray
Inverse of A.
"""
rcond = onp.max(A.shape) * onp.spacing(np.linalg.norm(A, ord=2))
return np.linalg.pinv(A, rcond=rcond)
def step(x: np.ndarray) -> np.ndarray:
    """
    Unit step function whose derivative is defined and equal to 0 at every point.

    The value at x = 0 is 0.

    Parameters
    ----------
    x : np.ndarray
        Array to apply step to.

    Returns
    -------
    step_x : np.ndarray
        step(x)
    """
    valueAtZero = 0
    step_x = np.heaviside(x, valueAtZero)
    return step_x
def criuCheckpoint(dir: str = "criu_checkpoint", user_mode: bool = False):
    """
    Use CRIU to create a checkpoint of your program.

    WARNING: You must ensure no external sockets are used, i.e., via matplotlib.
    WARNING: GPU memory cannot be mapped for all GPUs.
    WARNING: user_mode is a work in progress and is not full featured yet.

    The dump is launched in a detached background shell, so this function returns
    without waiting for CRIU to finish.

    Parameters
    ----------
    dir : str
        Directory where the CRIU checkpoint will be created. It is created, along with any
        missing parent directories, if it does not exist.
    user_mode : bool
        Whether to use user mode or not. In user mode CRIU is run without sudo and with its
        unprivileged option. (Default value = False)
    """
    import os
    import shlex
    from pathlib import Path

    path = Path(dir).absolute()
    pid = os.getpid()

    # Create path if it does not exist
    path.mkdir(parents=True, exist_ok=True)

    # Quote the directory so paths with spaces or shell metacharacters survive the shell
    quotedDir = shlex.quote(str(path))

    if user_mode:
        dumpCmd = f"criu dump --unprivileged -t {pid} -D {quotedDir} --shell-job --leave-running"
    else:
        dumpCmd = f"sudo criu dump -t {pid} -D {quotedDir} --shell-job --leave-running"

    # Run the command and print how to restart.
    # The "( ... &) &" form backgrounds the dump from a subshell so the call returns immediately.
    os.system(f"({dumpCmd} &) &")
    print(f"Creating a checkpoint. To restart run: sudo criu restore -D {quotedDir} --shell-job")