Class NllsClass

Class Documentation

class NllsClass
JIT-ed nonlinear least-squares class.
Like the :func:`NLLS <tfc.utils.TFCUtils.NLLS>` function, but in class form, so that the run method can be called multiple times without re-JIT-ing.
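
The benefit of the class form can be sketched without JAX: the expensive setup happens once in `__init__`, and `run` reuses it for new initial guesses and arguments. This is a minimal NumPy sketch under stated assumptions; `MiniNllsClass`, `build_count`, and the use of a plain Python closure in place of JIT compilation are all hypothetical and are not TFC's implementation. The stopping test is also simplified.

```python
import numpy as np


class MiniNllsClass:
    """Hypothetical stand-in for NllsClass: setup runs once, run() reuses it.

    Building a Python closure plays the role that JIT compilation plays in the
    real class; nothing here is TFC's implementation.
    """

    def __init__(self, res, jac, tol=1e-13, maxIter=50):
        self.build_count = 0
        self._solve = self._build(res, jac, tol, maxIter)

    def _build(self, res, jac, tol, maxIter):
        self.build_count += 1  # counts how many times the solver is "compiled"

        def solve(xi, *args):
            it = 0
            while it < maxIter:
                r = res(xi, *args)
                dxi = -np.linalg.pinv(jac(xi, *args)) @ r  # Gauss-Newton step
                xi = xi + dxi
                it += 1
                # simplified stopping test: small residual or small step
                if np.max(np.abs(r)) < tol or np.max(np.abs(dxi)) < tol:
                    break
            return xi, it

        return solve

    def run(self, xiInit, *args):
        # Reuses the solver built in __init__; nothing is rebuilt per call.
        return self._solve(np.asarray(xiInit, dtype=float), *args)


def res(xi, t, y):
    # linear model residual: xi[0] * t + xi[1] - y
    return xi[0] * t + xi[1] - y


def jac(xi, t, y):
    # Jacobian of res with respect to xi, shape (len(t), 2)
    return np.stack([t, np.ones_like(t)], axis=1)


solver = MiniNllsClass(res, jac)
t = np.linspace(0.0, 1.0, 5)
xi1, it1 = solver.run(np.zeros(2), t, 2.0 * t + 1.0)
xi2, it2 = solver.run(np.zeros(2), t, -3.0 * t + 0.5)
print(solver.build_count)  # 1: both runs reused the same solver
```

The second `run` call solves a different problem but does not trigger another build, which is the property the class form provides for the JIT-ed solver.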

Public Functions

__init__(self, PyTree xiInit, Callable res, *Any args, list[int] constant_arg_nums=[], Optional[Callable[..., np.ndarray]] J=None, Optional[Callable[[PyTree], bool]] cond=None, Optional[Callable[[PyTree], PyTree]] body=None, float tol=1e-13, uint maxIter=50, Literal["pinv", "lstsq"] method="pinv", bool timer=False, bool printOut=False, str printOutEnd="\n", str timerType="process_time", bool holomorphic=False)
Initialization function. Creates the JIT-ed nonlinear least-squares function.

Parameters
----------
xiInit : pytree or array-like
    Initial guess for the unknown parameters.

res : function
    Residual function (also known as the loss function) with signature res(xi,*args).

*args : iterable
    Any additional arguments taken by res other than xi.

constant_arg_nums : list[int], optional
    These arguments are removed from the residual function and treated as constants. See :func:`pejit <tfc.utils.TFCUtils.pejit>` for more details. (Default value = [])

J : function, optional
     User-specified Jacobian. If None, the Jacobian of res with respect to xi is calculated via automatic differentiation. (Default value = None)

cond : Optional[Callable[[PyTree],bool]]
     User-specified condition function. If None, the default cond function is used, which checks the three termination criteria
     given in the class description. (Default value = None)

body : Optional[Callable[[PyTree],PyTree]]
     User-specified while-loop body function. If None, the default body function is used, which updates xi with one NLLS iteration using the chosen method.
     (Default value = None)

tol : float, optional
     Tolerance used in the default termination criteria; see the class description for more details. (Default value = 1e-13)

maxIter : int, optional
     Maximum number of iterations. (Default value = 50)

method : Literal["pinv","lstsq"], optional
     Method for least-squares inversion. (Default value = "pinv")

     * pinv - Use np.linalg.pinv
     * lstsq - Use np.linalg.lstsq

timer : bool, optional
     Whether to time the code. (Default value = False) Setting this to True slightly increases runtime, because one iteration
     of the nonlinear least squares is run first so that the JAX trace is not included in the timing.

printOut : bool, optional
     Controls whether NLLS prints information at each iteration. The printout consists of the iteration number and max(abs(res)). (Default value = False)

printOutEnd : str, optional
     Value of the keyword argument end passed to the print statement used by printOut. (Default value = "\n")

timerType : str, optional
     Name of any timer function in the time module, e.g. "process_time" or "perf_counter". (Default value = "process_time")

holomorphic : bool, optional
     Indicates whether the residual function is promised to be holomorphic. (Default value = False)
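
The effect of constant_arg_nums can be pictured as baking some arguments into the residual so that the solver only receives the remaining ones. The sketch below does this with a closure in plain NumPy. The helper `bake_constant_args` is hypothetical, and it assumes the indices count 0-based positions in `*args` (not counting `xi`), which may not match TFC's convention; per the parameter description, NllsClass itself handles these arguments through `pejit`.

```python
import numpy as np


def bake_constant_args(res, args, constant_arg_nums):
    """Hypothetical helper: fix some of res's extra arguments as constants.

    ASSUMPTION: indices in constant_arg_nums are 0-based positions in *args
    (xi is not counted). Returns the reduced residual and the remaining args.
    """
    const = {k: args[k] for k in constant_arg_nums}
    remaining = tuple(a for k, a in enumerate(args) if k not in const)
    n_total = len(args)

    def res_reduced(xi, *free_args):
        free = iter(free_args)
        # rebuild the full argument list, pulling constants from the closure
        full = [const[k] if k in const else next(free) for k in range(n_total)]
        return res(xi, *full)

    return res_reduced, remaining


def res(xi, t, y, w):
    # weighted linear residual
    return w * (xi[0] * t + xi[1] - y)


t = np.linspace(0.0, 1.0, 4)
y = 2.0 * t + 1.0
res_c, remaining = bake_constant_args(res, (t, y, 3.0), constant_arg_nums=[0, 2])
xi = np.array([0.5, -1.0])
print(len(remaining))  # 1: only y is still passed at call time
print(np.allclose(res_c(xi, *remaining), res(xi, t, y, 3.0)))  # True
```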
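
The J, cond, body, and method parameters fit together as a while loop: cond decides whether to keep iterating and body performs one update. The NumPy sketch below imitates that structure with a plain Python loop that has the same semantics as `jax.lax.while_loop`. The state keys (`xi`, `dxi`, `it`, `args`), the Gauss-Newton update, and the three stopping tests (residual above tol, step above tol, iteration count below maxIter) are assumptions made for illustration, not TFC's source.

```python
import numpy as np

TOL, MAX_ITER = 1e-13, 50


def res(xi, t, y):
    # exponential model residual: a * exp(b * t) - y
    return xi[0] * np.exp(xi[1] * t) - y


def J(xi, t, y):
    # user-specified Jacobian d(res)/d(xi), shape (len(t), 2)
    e = np.exp(xi[1] * t)
    return np.stack([e, xi[0] * t * e], axis=1)


def make_body(method):
    def body(val):
        # one Gauss-Newton update: solve J @ dxi ~= -res in the least-squares sense
        xi, args = val["xi"], val["args"]
        Jx, r = J(xi, *args), res(xi, *args)
        if method == "pinv":
            dxi = -np.linalg.pinv(Jx) @ r
        elif method == "lstsq":
            dxi = -np.linalg.lstsq(Jx, r, rcond=None)[0]
        else:
            raise ValueError(f"unknown method: {method!r}")
        return {"xi": xi + dxi, "dxi": dxi, "it": val["it"] + 1, "args": args}

    return body


def cond(val):
    # keep iterating only while none of the (assumed) stopping tests is met
    return bool(
        np.max(np.abs(res(val["xi"], *val["args"]))) > TOL
        and np.max(np.abs(val["dxi"])) > TOL
        and val["it"] < MAX_ITER
    )


def while_loop(cond_fun, body_fun, val):
    # plain-Python equivalent of jax.lax.while_loop's semantics
    while cond_fun(val):
        val = body_fun(val)
    return val


t = np.linspace(0.0, 1.0, 8)
y = 1.5 * np.exp(-0.7 * t)
init = {"xi": np.array([1.0, -0.5]), "dxi": np.ones(2), "it": 0, "args": (t, y)}
out_pinv = while_loop(cond, make_body("pinv"), init)
out_lstsq = while_loop(cond, make_body("lstsq"), init)
```

For a full-rank Jacobian both methods produce the same step; pinv forms the pseudoinverse explicitly, while lstsq solves the least-squares problem directly.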

run(self, PyTree xiInit, *Any args)
Runs the JIT-ed nonlinear least-squares function and times it if desired.

Parameters
----------
xiInit : PyTree
    Initial guess for the unknown parameters.

*args : Any
    Any additional arguments taken by res other than xi.

Returns
-------
xi : PyTree
     Unknowns that minimize res as found via least squares. The type is the same as that of xiInit given as input.

it : int
     Number of NLLS iterations performed.

time : float
     Computation time as measured by the timer named in timerType. This output is only returned if timer = True.
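
The timer and timerType options amount to an untimed warm-up followed by a timed call using the clock named by timerType. The sketch below is a simplification under stated assumptions: `timed_run` is hypothetical, and it warms up with a full call to `fn`, whereas the timer parameter description says NllsClass runs a single nonlinear least-squares iteration first.

```python
import time


def timed_run(fn, *args, timerType="process_time"):
    """Hypothetical helper: call fn once untimed, then time a second call.

    The untimed call stands in for the warm-up that keeps JAX tracing and
    compilation out of the measurement.
    """
    clock = getattr(time, timerType)  # e.g. time.process_time or time.perf_counter
    fn(*args)  # warm-up call, not timed
    start = clock()
    result = fn(*args)
    return result, clock() - start


result, elapsed = timed_run(sum, range(1000), timerType="perf_counter")
print(result)  # 499500
```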

Public Members

timerType
timer
holomorphic

Protected Attributes

_maxIter
_dictFlag
_nlls
_compiled