
.. _program_listing_file_src_tfc_utils_CeSolver.py:

Program Listing for File CeSolver.py
====================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_tfc_utils_CeSolver.py>` (``src/tfc/utils/CeSolver.py``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: py

   import sympy as sp
   from sympy import Expr
   from sympy.core.function import AppliedUndef
   from sympy.printing.pycode import PythonCodePrinter
   from sympy.simplify.simplify import nc_simplify
   from .tfc_types import ConstraintOperators, Exprs, Any, Literal, ConstraintOperator
   from .TFCUtils import TFCPrint
   
   
   class CeSolver:
       """
       Constrained expression solver.
   
       This class solves constrained expressions for you.
   
       Parameters
       ----------
       C : ConstraintOperators
           This is a tuple or list constraint operators. Each element in the
           iterable should be a Python function that takes in a sympy function,
           such as `g(x)`, and outputs that function evaluated in the same way
           as the function in the constraint. For example, if the constraint was
           u(3) = 0, then the assocaited constraint operator would be:
           `lambda u: = u.subs(x,3)`.
       kappa : Exprs
           This is a tuple or list of the kappa portion of each constraint. For the
           example u(3) = 0, the kappa portion is simply 0. Note, the standard
           Python 0 should not be used, but rather, sympy.re(0).
       s : Exprs
           This is a tuple or list of the support functions. These should be given
           in terms of sympy symbols and constants. For example, if we wanted to
           use the constant function x = 1 as a support function, then we would
           use sympy.re(1) in this iterable.
       g : AppliedUndef| Any
           This is the free function used in the constrained expression. For example,
           `g(x)`.
   
       References
       ----------
       The algorithm used here is given at 26:13 of this video: https://www.youtube.com/watch?v=uisOZVBHA2U&t=1573s
   
       Examples
       --------
       Consider the constraints `u(0) = 2` and `u_x(2) = 1` where `u_x` is the derivative of `u(x)` with respect to `x`.
       Moreover, suppose we want to use `g(x)` as the free function.
   
       import sympy as sp
       from tfc.utils import CeSolver
   
       x = sp.Symbol("x")
       y = sp.Symbol("y")
       u = sp.Function("u")
       g = sp.Function("g")
   
       C = [lambda u: u.subs(x,0), lambda u: sp.diff(u,x).subs(x,2)]
       K = [sp.re(2), sp.re(1)]
       s = [sp.re(1), x]
   
       cs = CeSolver(C,K,s, g(x))
       ce = cs.ce
   
       In the above code example, `ce` is the constrained expression that satisfies these constraints.
       """
   
       def __init__(self, C: ConstraintOperators, kappa: Exprs, s: Exprs, g: AppliedUndef | Any):
           self._C = C
           self._K = kappa
           self._s = s
           self._g = g
           self._ce_stale: bool = True
           self._S_stale: bool = True
           self._alpha_stale: bool = True
           self._phi_stale: bool = True
           self._rho_stale: bool = True
   
       @property
       def print_type(self) -> Literal["tfc", "pretty", "latex", "str"]:
           return self._print_type
   
       @print_type.setter
       def print_type(self, print_type: Literal["tfc", "pretty", "latex", "str"]) -> None:
           from sympy import init_printing
   
           self._print_type = print_type
           if self._print_type == "tfc":
               tfc_printer = TfcPrinter()
               init_printing(pretty_print=True, pretty_printer=tfc_printer.doprint)
           elif self._print_type == "str":
               init_printing(pretty_print=False)
           elif self._print_type == "pretty":
               init_printing()
           elif self._print_type == "latex":
               from sympy import latex
   
               init_printing(pretty_print=True, pretty_printer=latex)
           else:
               TFCPrint.Error(
                   f'print_type was specified as {print_type} but only "tfc", "pretty", "latex", and "str" are accepted.'
               )
   
       @property
       def ce(self) -> Any:
           """
           Constrained expression.
   
           Returns
           -------
           Any
               Constrained expression.
           """
           if self._ce_stale:
               self._solveCe()
               self._ce_stale = False
           return self._ce
   
       @ce.setter
       def ce(self, ce: Any) -> None:
           """
           Sets the constrained expression to the user-supplied value.
   
           Parameters
           ----------
           ce: Any
               Sympy representation of the constrained expression.
           """
           self._ce_stale = False
           self._ce = ce
   
       @property
       def phi(self) -> Any:
           """
           Switching functions.
   
           Returns
           -------
           Any
               Switching functions.
           """
           if self._phi_stale:
               s_vec = sp.Matrix([s for s in self._s])
               self._phi = s_vec.transpose() * self.alpha
               self._phi_stale = False
           return self._phi
   
       @phi.setter
       def phi(self, phi: Any):
           """
           Set the switching functions.
   
           Parameters
           ----------
           phi : Any
               The switching functions.
           """
           self._phi = phi
           self._phi_stale = False
   
       @property
       def alpha(self) -> sp.Matrix:
           """
           Alpha matrix (inverse of the support matrix)
   
           Returns
           sp.Matrix
               alpha matrix. The elements are on the field over which
               the constrained expression is defined.
           """
           if self._alpha_stale:
               self._alpha = self.S.inv()
               self._alpha_stale = False
           return self._alpha
   
       @property
       def S(self) -> sp.Matrix:
           """
           Support matrix.
   
           Returns
           sp.Matrix
               Support matrix. The elements are on the field over which
               the constrained expression is defined.
           """
   
           def _applyC(c: ConstraintOperator, s) -> Any:
               """
               Apply the constraint operator to the switching function.
   
               Parameters
               ----------
               c : ConstraintOperator
                   Constraint operator.
               s : Expr
                   Switching function.
   
               Returns
               -------
               Any
                   c(s), which is a number on the field over which the
                   constrained expression is defined.
               """
   
               dark = c(s)
               if isinstance(dark, sp.Matrix) or isinstance(dark, sp.MatMul):
                   dark = nc_simplify(dark)[0]
               return dark
   
           if self._S_stale:
               self._S = sp.Matrix([[_applyC(c, s) for s in self._s] for c in self._C])
               self._S_stale = False
           return self._S
   
       @property
       def rho(self) -> Any:
           """
           Projection functionals.
   
           Returns
           -------
           Any
               Projection functionals.
           """
           if self._rho_stale:
               self._rho = sp.Matrix(
                   [sp.Add(kappa, -self._C[k](self._g)) for k, kappa in enumerate(self._K)]
               )
           return self._rho
   
       @property
       def s(self) -> Exprs:
           """
           Switching functions.
   
           Returns
           -------
           Exprs
               Support functions.
           """
           return self._s
   
       @s.setter
       def s(self, s: Exprs) -> None:
           """
           Set the support functions.
   
           Parameters
           ----------
           s : Exprs
               This is a tuple or list of the support functions. These should be given
               in terms of sympy symbols and constants. For example, if we wanted to
               use the constant function x = 1 as a support function, then we would
               use sympy.re(1) in this iterable.
           """
           self._s = s
           self._S_stale = True
           self._alpha_stale = True
           self._phi_stale = True
           self._ce_stale = True
   
       @property
       def kappa(self):
           """
           Kappa values.
   
           Returns
           -------
           Exprs
               Kappa values.
           """
           return self._K
   
       @kappa.setter
       def kappa(self, kappa: Exprs) -> None:
           """
           Set the kappa values.
   
           Parameters
           -------
           kappa : Exprs
               This is a tuple or list of the kappa portion of each constraint. For the
               example u(3) = 0, the kappa portion is simply 0. Note, the standard
               Python 0 should not be used, but rather, sympy.re(0).
           """
           self._K = kappa
           self._rho_stale = True
           self._ce_stale = True
   
       @property
       def C(self) -> ConstraintOperators:
           """
           Constraint operators.
   
           Returns
           -------
           ConstraintOperators
               Constraint operators.
           """
           return self._C
   
       @C.setter
       def C(self, C: ConstraintOperators) -> None:
           """
           Parameters
           ----------
           C : ConstraintOperators
               This is a tuple or list constraint operators. Each element in the
               iterable should be a Python function that takes in a sympy function,
               such as `g(x)`, and outputs that function evaluated in the same way
               as the function in the constraint. For example, if the constraint was
               u(3) = 0, then the assocaited constraint operator would be:
               `lambda u: = u.subs(x,3)`.
           """
           self._C = C
           self._S_stale = True
           self._alpha_stale = True
           self._phi_stale = True
           self._rho_stale = True
           self._ce_stale = True
   
       @property
       def g(self) -> AppliedUndef | Any:
           """
           Free function.
   
           Returns
           -------
           AppliedUndef | Any
               Free function.
           """
           return self._g
   
       @g.setter
       def g(self, g: AppliedUndef | Any) -> None:
           """
           Set the free function.
   
           Parameters
           ----------
           g : AppliedUndef | Any
               This is the free function used in the constrained expression. For example,
               `g(x)`.
           """
           self._g = g
           self._ce_stale = True
           self._rho_stale = True
   
       def _solveCe(self) -> None:
           """
           Solves the constrained expression and stores it in self.ce
           """
           self._ce = sp.Add(self.gg, (self.phiphi * self.rho)[0])
   
       def checkCe(self) -> bool:
           """
           Checks the constrained expression stored in the class against the stored constraints.
   
           Return
           ------
           bool:
               Returns True if the constraint expression satisfies the constraints and false otherwise.
           """
           checks = [sp.simplify(c(self.cece)) == sp.simplify(k) for c, k in zip(self._C, self._K)]
           ret = True
           for k, check in enumerate(checks):
               if not check:
                   TFCPrint.Warning(
                       f"Expected result of constraint {k+1} to be {self._K[k]}, but got {self._C[k](self.ce)}."
                   )
                   ret = False
           return ret
   
   
   class TfcPrinter(PythonCodePrinter):
       def __init__(self, settings=None):
           # Switch math to numpy
           for k, v in self._kf.items():
               if "math" in v:
                   self._kf[k] = v.replace("math", "np")
           for k, v in self._kc.items():
               if "math" in v:
                   self._kc[k] = v.replace("math", "np")
   
           super().__init__(settings=settings)
   
       def _hprint_Pow(self, expr: Expr, rational: bool = False, sqrt: str = "np.sqrt"):
           """
           Override _hprint_Pow to use np.sqrt rather than math.sqrt
   
           Parameters
           ----------
           expr : Expr
               Expression to print.
           rational : bool
               Whether the expression is rational.
           sqrt : str
               String to print for the sqrt.
   
           Returns
           -------
           str
               String to print.
           """
   
           return super()._hprint_Pow(expr, rational=rational, sqrt=sqrt)
   
       def _print_Symbol(self, expr: Expr) -> str:
           """
           Add in Symbol printing function.
   
           Parameters
           ----------
           expr : Expr
               Symbol to print.
   
           Returns
           -------
           str
               String to print.
           """
           return self._print(str(expr))
   
       def _print_Function(self, expr: Expr) -> str:
           """
           Add in Function printing function.
   
           Parameters
           ----------
           expr : Expr
               Function to print.
   
           Returns
           -------
           str
               String to print.
           """
           return self._print(str(expr))
   
       def _print_Subs(self, subs: Expr) -> str:
           """
           Substitute values.
   
           Parameters
           ----------
           subs : Expr
               Substitution(s).
   
           Returns
           -------
           str
               String to print.
           """
           expr, old, new = subs.args
           expr = self._print(expr)
           for k in range(len(old)):
               expr = expr.replace(self._print(old[k]), self._print(new[k]))
           return expr
   
       def _print_Derivative(self, expr: Expr) -> str:
           """
           Add in derivative printing function that uses egrad.
   
           Parameters
           ----------
           expr : Expr
               Expression to print.
   
           Returns
           -------
           str
               String to print.
           """
           # Function will be the full function, e.g., g(x,y)
           # vars will be the derivative symbol and order.
           # For example, dg/dx will have vars [(x,1)]
           function, *vars = expr.args
   
           # Find position for each derivative
           function_vars = function.args
           position_vars = []
           for var in vars:
               ind = function_vars.index(var[0])
               position_vars.append((ind, var[1]))
   
           # If you want the printer to work correctly for nested
           # expressions then use self._print() instead of str() or latex().
           # See the example of nested modulo below in the custom printing
           # method section.
           name = function.func.__name__
           parenthesis_counter = 0
           ret = ""
           for pv in position_vars:
               for _ in range(pv[1]):
                   ret += self._print("egrad(" + name + "," + str(pv[0]))
                   parenthesis_counter += 1
           for _ in range(parenthesis_counter):
               ret += self._print(")")
           ret += "(" + "".join(self._print(i[0]) for i in vars) + ")"
   
           return ret
