Class LsClass
Defined in File TFCUtils.py
Class Documentation
- class LsClass
JIT-ed linear least-squares class. It works like the :func:`LS <tfc.utils.TFCUtils.LS>` function, but in class form, so that the run method can be called multiple times without re-JITing. See :func:`LS <tfc.utils.TFCUtils.LS>` for more details.
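The class form matters because `jax.jit` compiles a function once per combination of input shapes and dtypes and caches the result. The following is a minimal sketch of that pattern in plain JAX, not the tfc implementation. `LinearLS` is a hypothetical stand-in for LsClass, plain arrays stand in for general PyTrees, and the sketch assumes that, for a residual that is linear in xi, a single step xi = zXi - pinv(J) res(zXi) from the initial guess gives the least-squares solution.

```python
import jax
import jax.numpy as jnp


class LinearLS:
    """Hypothetical stand-in for LsClass: build the jitted solve once, reuse it in run."""

    def __init__(self, res):
        # Jacobian of the residual with respect to its first argument (xi).
        jac = jax.jacfwd(res, argnums=0)

        def solve(zXi, *args):
            # For a residual linear in xi, one Gauss-Newton step from any
            # starting point zXi lands on the least-squares solution.
            J = jac(zXi, *args)
            return zXi - jnp.linalg.pinv(J) @ res(zXi, *args)

        # Compiled on the first call for given argument shapes/dtypes;
        # later calls with matching shapes reuse the cached executable.
        self._solve = jax.jit(solve)

    def run(self, zXi, *args):
        return self._solve(zXi, *args)


def res(xi, A, b):
    # Linear residual: r(xi) = A @ xi - b
    return A @ xi - b


A = jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
ls = LinearLS(res)
zXi = jnp.zeros(2)
for b in (jnp.array([1.0, 2.0, 2.0]), jnp.array([2.0, 4.0, 4.0])):
    print(ls.run(zXi, A, b))
```

Because the jitted function is built once in `__init__`, repeated `run` calls with arrays of the same shape and dtype reuse the compiled code; passing arrays of a new shape triggers a fresh compile.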
Public Functions
- __init__(self, PyTree zXi, Callable res, *Any args, list[int] constant_arg_nums=[], Optional[Callable[..., np.ndarray]] J=None, Literal["pinv", "lstsq"] method="pinv", bool timer=False, str timerType="process_time", bool holomorphic=False)
Initialization function. Creates the JIT-ed least-squares function.

Parameters
----------
zXi : PyTree
    Unknown parameters to be found using least-squares.
res : Callable
    Residual function (also known as the loss function) with signature res(xi: PyTree, *args: Any, **kwargs: Any). The first argument does not need to be named xi; the name is only illustrative.
*args : Any
    Any additional arguments taken by res other than the first PyTree argument.
constant_arg_nums : list[int], optional
    These arguments will be removed from the residual function and treated as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details. (Default value = [])
J : Optional[Callable[..., np.ndarray]]
    User-specified Jacobian function. If None, the Jacobian of res with respect to xi is calculated via automatic differentiation. (Default value = None)
method : Literal["pinv", "lstsq"], optional
    Method for least-squares inversion. (Default value = "pinv")

    * pinv - Use np.linalg.pinv
    * lstsq - Use np.linalg.lstsq
timer : bool, optional
    Whether to time the code. (Default value = False) Setting this to True slightly increases runtime, because the least-squares function is run once before timing so that the JAX trace is not included in the measured time.
timerType : str, optional
    Any timer from the time module. (Default value = "process_time")
holomorphic : bool, optional
    Indicates whether the residual function is promised to be holomorphic. (Default value = False)
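The effect of the `J`, `method`, and `holomorphic` options on a single solve can be sketched as below. `ls_step` is a hypothetical helper written for illustration: it uses `jax.jacfwd` for the automatic-differentiation path and `jnp.linalg.pinv` / `jnp.linalg.lstsq` for the two inversion methods. It does not model `constant_arg_nums`, timing, PyTree unknowns, or JIT compilation, so treat it as a sketch of the idea rather than the tfc code.

```python
import jax
import jax.numpy as jnp


def ls_step(res, zXi, *args, J=None, method="pinv", holomorphic=False):
    """Hypothetical helper: one linear least-squares solve starting from zXi."""
    if J is None:
        # No user Jacobian: differentiate res with respect to its first argument.
        J = jax.jacfwd(res, argnums=0, holomorphic=holomorphic)
    Jmat = J(zXi, *args)
    r = res(zXi, *args)
    if method == "pinv":
        dxi = jnp.linalg.pinv(Jmat) @ r
    elif method == "lstsq":
        # lstsq returns (solution, residuals, rank, singular values).
        dxi = jnp.linalg.lstsq(Jmat, r)[0]
    else:
        raise ValueError(f"unknown method: {method!r}")
    return zXi - dxi


def res(xi, x, y):
    # Straight-line fit: y is approximated by xi[0] + xi[1] * x
    return xi[0] + xi[1] * x - y


x = jnp.linspace(0.0, 1.0, 5)
y = 3.0 + 2.0 * x
zXi = jnp.zeros(2)

xi_pinv = ls_step(res, zXi, x, y, method="pinv")
xi_lstsq = ls_step(res, zXi, x, y, method="lstsq")


def J_user(xi, x, y):
    # Hand-written Jacobian: columns are d(res)/d(xi[0]) = 1 and d(res)/d(xi[1]) = x.
    return jnp.stack([jnp.ones_like(x), x], axis=1)


xi_user = ls_step(res, zXi, x, y, J=J_user)
print(xi_pinv, xi_lstsq, xi_user)
```

For a full-rank Jacobian, such as the straight-line fit above, both methods and both Jacobian paths should agree up to round-off.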
- run(self, PyTree zXi, *Any args)
Runs the JIT-ed least-squares function and times it if desired.

Parameters
----------
zXi : PyTree
    Unknown parameters to be found using least-squares.
*args : Any
    Any additional arguments taken by res other than xi.

Returns
-------
xi : PyTree
    Unknowns that minimize res as found via least-squares. The type will be the same as the zXi specified in the input.
time : float, optional
    Computation time as calculated by the specified timerType. This output is only returned if timer = True.
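The documented timer behaviour (one untimed call to absorb the JAX trace, then a timed call using a function from the `time` module) can be sketched as follows. `timed_run` is a hypothetical helper, not part of tfc, and the calls to `jax.block_until_ready` are an assumption added here so that JAX's asynchronous dispatch does not stop the clock before the result has actually been computed.

```python
import time

import jax
import jax.numpy as jnp


def timed_run(fn, *args, timerType="process_time"):
    """Hypothetical helper: run fn once untimed, then time a second call."""
    # Look up the timer by name, e.g. time.process_time or time.perf_counter.
    timer = getattr(time, timerType)
    # Warm-up call: triggers tracing/compilation so it is not timed.
    jax.block_until_ready(fn(*args))
    start = timer()
    # Block so the measurement covers the actual computation.
    out = jax.block_until_ready(fn(*args))
    return out, timer() - start


solve = jax.jit(lambda A, b: jnp.linalg.pinv(A) @ b)
A = jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
b = jnp.array([1.0, 2.0, 2.0])
xi, elapsed = timed_run(solve, A, b, timerType="perf_counter")
print(xi, elapsed)
```

Passing `timerType="perf_counter"` measures wall-clock time, while the default `"process_time"` counts the CPU time of the current process; in both cases the first, compiling call is excluded from the measurement.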