Class NllsClass
Defined in File TFCUtils.py
Class Documentation
- class NllsClass
JIT-ed nonlinear least-squares class. Like the :func:`NLLS <tfc.utils.TFCUtils.NLLS>` function, but in class form so that the run method can be called multiple times without re-JITing.
Public Functions
- __init__(self, PyTree xiInit, Callable res, *Any args, list[int] constant_arg_nums=[], Optional[Callable[..., np.ndarray]] J=None, Optional[Callable[[PyTree], bool]] cond=None, Optional[Callable[[PyTree], PyTree]] body=None, float tol=1e-13, uint maxIter=50, Literal["pinv", "lstsq"] method="pinv", bool timer=False, bool printOut=False, str printOutEnd="\n", str timerType="process_time", bool holomorphic=False)
Initialization function. Creates the JIT-ed nonlinear least-squares function.

Parameters
----------
xiInit : pytree or array-like
    Initial guess for the unknown parameters.
res : function
    Residual function (also known as the loss function) with signature res(xi,*args).
*args : iterable
    Any additional arguments taken by res other than xi.
constant_arg_nums : list[int], optional
    These arguments will be removed from the residual function and treated as constant. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details.
J : function
    User specified Jacobian. If None, then the Jacobian of res with respect to xi will be calculated via automatic differentiation. (Default value = None)
cond : Optional[Callable[[PyTree],bool]]
    User specified condition function. If None, then the default cond function is used, which checks the three termination criteria provided in the class description. (Default value = None)
body : Optional[Callable[[PyTree],PyTree]]
    User specified while-loop body function. If None, then the default body function is used, which updates xi using an NLLS iteration and the method provided. (Default value = None)
tol : float
    Tolerance used in the default termination criteria: see class description for more details. (Default value = 1e-13)
maxIter : int, optional
    Maximum number of iterations. (Default value = 50)
method : Literal["pinv","lstsq"], optional
    Method for least-squares inversion. (Default value = "pinv")
    * pinv - Use np.linalg.pinv
    * lstsq - Use np.linalg.lstsq
timer : bool, optional
    Boolean that chooses whether to time the code or not. (Default value = False). Note that setting this to True slightly increases the runtime, because one iteration of the nonlinear least squares is run first to avoid timing the JAX trace.
printOut : bool, optional
    Controls whether the NLLS prints out information each iteration or not. The printout consists of the iteration and max(abs(res)) at each iteration. (Default value = False)
printOutEnd : str, optional
    Value of keyword argument end passed to the print statement used in printOut. (Default value = "\n")
timerType : str, optional
    Any timer from the time module. (Default value = "process_time")
holomorphic : bool, optional
    Indicates whether residual function is promised to be holomorphic. (Default value = False)
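To make the `method` option concrete, the following is a minimal NumPy sketch of a nonlinear least-squares iteration that solves each linearized step with either `np.linalg.pinv` or `np.linalg.lstsq`. It is not the library's implementation: the helper `gauss_newton`, the hand-written Jacobian, the exponential test problem, and the simplified stopping rule (residual below `tol` or the iteration cap reached) are all assumptions made for illustration. The actual default `cond` checks the three termination criteria given in the class description.

```python
import numpy as np


def gauss_newton(res, jac, xi_init, tol=1e-13, max_iter=50, method="pinv"):
    """Hypothetical helper (not part of TFC): Gauss-Newton least-squares loop."""
    xi = np.asarray(xi_init, dtype=float)
    it = 0
    # Assumed stopping rule: max|res| at or below tol, or iteration cap reached.
    while it < max_iter and np.max(np.abs(res(xi))) > tol:
        r = res(xi)
        J = jac(xi)
        if method == "pinv":
            # Step from the Moore-Penrose pseudo-inverse of the Jacobian.
            dxi = np.linalg.pinv(J) @ r
        elif method == "lstsq":
            # Step from a direct least-squares solve of J @ dxi = r.
            dxi = np.linalg.lstsq(J, r, rcond=None)[0]
        else:
            raise ValueError(f"unknown method: {method!r}")
        xi = xi - dxi
        it += 1
    return xi, it


# Illustrative problem: recover a = 2, b = -1 in y = a * exp(b * t) from exact data.
t = np.linspace(0.0, 1.0, 20)
y = 2.0 * np.exp(-1.0 * t)


def res(xi):
    # Residual between the model and the data.
    return xi[0] * np.exp(xi[1] * t) - y


def jac(xi):
    # Columns are d(res)/da and d(res)/db.
    return np.column_stack([np.exp(xi[1] * t), xi[0] * t * np.exp(xi[1] * t)])


xi_pinv, it_pinv = gauss_newton(res, jac, [1.8, -0.9], method="pinv")
xi_lstsq, it_lstsq = gauss_newton(res, jac, [1.8, -0.9], method="lstsq")
```

In this sketch both methods compute the same least-squares step for a full-rank Jacobian, so the two results should agree to within floating-point error.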
- run(self, PyTree xiInit, *Any args)
Runs the JIT-ed nonlinear least-squares function and times it if desired.

Parameters
----------
xiInit : PyTree
    Initial guess for the unknown parameters.
*args : Any
    Any additional arguments taken by res other than xi.

Returns
-------
xi : PyTree
    Unknowns that minimize res as found via least-squares. Type will be the same as xiInit specified in the input.
it : int
    Number of NLLS iterations performed.
time : float
    Computation time as calculated by the timerType specified. This output is only returned if timer = True.
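The timing behavior described above can be pictured with a small standard-library sketch. It assumes that a `timerType` string is resolved by looking up the function of that name in the `time` module, and that a warm-up call is made before the timed call so that one-time setup cost (the JAX trace, in the library's case) is excluded. The helper `timed_call` is hypothetical and not part of TFC, and it warms up with a full call rather than a single iteration.

```python
import time


def timed_call(fn, *args, timer_type="process_time"):
    """Hypothetical helper (not part of TFC): time fn(*args) after a warm-up call."""
    # Resolve the timer by name, e.g. time.process_time or time.perf_counter.
    timer = getattr(time, timer_type)
    # Warm-up call, excluded from the measurement.
    fn(*args)
    start = timer()
    out = fn(*args)
    elapsed = timer() - start
    return out, elapsed


total, elapsed = timed_call(sum, [1, 2, 3])
total_pc, elapsed_pc = timed_call(sum, [1, 2, 3], timer_type="perf_counter")
```

The warm-up call is why enabling timing adds a small amount of runtime: the work is done once untimed and then again under the timer.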