Source code for tfc.utils.CeSolver

import sympy as sp
from sympy import Expr
from sympy.core.function import AppliedUndef
from sympy.printing.pycode import PythonCodePrinter
from sympy.simplify.simplify import nc_simplify
from .tfc_types import ConstraintOperators, Exprs, Any, Literal, ConstraintOperator
from .TFCUtils import TFCPrint


class CeSolver:
    """
    Constrained expression solver.

    This class solves constrained expressions for you.

    Parameters
    ----------
    C : ConstraintOperators
        This is a tuple or list of constraint operators. Each element in the
        iterable should be a Python function that takes in a sympy function,
        such as `g(x)`, and outputs that function evaluated in the same way as
        the function in the constraint. For example, if the constraint was
        u(3) = 0, then the associated constraint operator would be:
        `lambda u: u.subs(x, 3)`.
    kappa : Exprs
        This is a tuple or list of the kappa portion of each constraint. For
        the example u(3) = 0, the kappa portion is simply 0. Note, the standard
        Python 0 should not be used, but rather, sympy.re(0).
    s : Exprs
        This is a tuple or list of the support functions. These should be
        given in terms of sympy symbols and constants. For example, if we
        wanted to use the constant function x = 1 as a support function, then
        we would use sympy.re(1) in this iterable.
    g : AppliedUndef | Any
        This is the free function used in the constrained expression. For
        example, `g(x)`.

    References
    ----------
    The algorithm used here is given at 26:13 of this video:
    https://www.youtube.com/watch?v=uisOZVBHA2U&t=1573s

    Examples
    --------
    Consider the constraints `u(0) = 2` and `u_x(2) = 1` where `u_x` is the
    derivative of `u(x)` with respect to `x`. Moreover, suppose we want to use
    `g(x)` as the free function.

    .. code-block:: python

        import sympy as sp
        from tfc.utils import CeSolver

        x = sp.Symbol("x")
        y = sp.Symbol("y")
        u = sp.Function("u")
        g = sp.Function("g")

        C = [lambda u: u.subs(x, 0), lambda u: sp.diff(u, x).subs(x, 2)]
        K = [sp.re(2), sp.re(1)]
        s = [sp.re(1), x]

        cs = CeSolver(C, K, s, g(x))
        ce = cs.ce

    In the above code example, `ce` is the constrained expression that
    satisfies these constraints.
    """

    def __init__(self, C: ConstraintOperators, kappa: Exprs, s: Exprs, g: AppliedUndef | Any):
        self._C = C
        self._K = kappa
        self._s = s
        self._g = g

        # No printing mode is selected until the print_type setter runs; this
        # keeps the getter from raising AttributeError before then.
        self._print_type = None

        # Lazy-evaluation flags. Dependencies: S -> alpha -> phi, and
        # ce = g + phi * rho, so setters mark every downstream quantity stale.
        self._ce_stale: bool = True
        self._S_stale: bool = True
        self._alpha_stale: bool = True
        self._phi_stale: bool = True
        self._rho_stale: bool = True

    @property
    def print_type(self) -> Literal["tfc", "pretty", "latex", "str"] | None:
        """
        Printing mode last selected through this object.

        Returns
        -------
        Literal["tfc", "pretty", "latex", "str"] | None
            The selected mode, or None if no mode has been set yet.
        """
        return self._print_type

    @print_type.setter
    def print_type(self, print_type: Literal["tfc", "pretty", "latex", "str"]) -> None:
        """
        Configure sympy's global printing via ``sympy.init_printing``.

        Parameters
        ----------
        print_type : Literal["tfc", "pretty", "latex", "str"]
            "tfc" uses TfcPrinter, "str" disables pretty printing, "pretty"
            uses sympy's default, and "latex" prints LaTeX.
        """
        from sympy import init_printing

        if print_type == "tfc":
            tfc_printer = TfcPrinter()
            init_printing(pretty_print=True, pretty_printer=tfc_printer.doprint)
        elif print_type == "str":
            init_printing(pretty_print=False)
        elif print_type == "pretty":
            init_printing()
        elif print_type == "latex":
            from sympy import latex

            init_printing(pretty_print=True, pretty_printer=latex)
        else:
            TFCPrint.Error(
                f'print_type was specified as {print_type} but only "tfc", "pretty", "latex", and "str" are accepted.'
            )
            # Do not record a value that was rejected.
            return
        self._print_type = print_type

    @property
    def ce(self) -> Any:
        """
        Constrained expression.

        Returns
        -------
        Any
            Constrained expression.
        """
        if self._ce_stale:
            self._solveCe()
            self._ce_stale = False
        return self._ce

    @ce.setter
    def ce(self, ce: Any) -> None:
        """
        Sets the constrained expression to the user-supplied value.

        Parameters
        ----------
        ce : Any
            Sympy representation of the constrained expression.
        """
        self._ce_stale = False
        self._ce = ce

    @property
    def phi(self) -> Any:
        """
        Switching functions, computed as s^T * alpha.

        Returns
        -------
        Any
            Switching functions.
        """
        if self._phi_stale:
            s_vec = sp.Matrix([s for s in self._s])
            self._phi = s_vec.transpose() * self.alpha
            self._phi_stale = False
        return self._phi

    @phi.setter
    def phi(self, phi: Any):
        """
        Set the switching functions.

        Parameters
        ----------
        phi : Any
            The switching functions.
        """
        self._phi = phi
        self._phi_stale = False
        # ce is built from phi, so it must be rebuilt on next access.
        self._ce_stale = True

    @property
    def alpha(self) -> sp.Matrix:
        """
        Alpha matrix (inverse of the support matrix).

        Returns
        -------
        sp.Matrix
            Alpha matrix. The elements are on the field over which the
            constrained expression is defined.
        """
        if self._alpha_stale:
            self._alpha = self.S.inv()
            self._alpha_stale = False
        return self._alpha

    @property
    def S(self) -> sp.Matrix:
        """
        Support matrix: entry (i, j) is constraint operator i applied to
        support function j.

        Returns
        -------
        sp.Matrix
            Support matrix. The elements are on the field over which the
            constrained expression is defined.
        """

        def _applyC(c: ConstraintOperator, s) -> Any:
            """
            Apply the constraint operator to a support function.

            Parameters
            ----------
            c : ConstraintOperator
                Constraint operator.
            s : Expr
                Support function.

            Returns
            -------
            Any
                c(s), which is a number on the field over which the
                constrained expression is defined.
            """
            dark = c(s)
            # Matrix-valued results are reduced to their first element.
            if isinstance(dark, (sp.Matrix, sp.MatMul)):
                dark = nc_simplify(dark)[0]
            return dark

        if self._S_stale:
            self._S = sp.Matrix([[_applyC(c, s) for s in self._s] for c in self._C])
            self._S_stale = False
        return self._S

    @property
    def rho(self) -> Any:
        """
        Projection functionals: kappa_k - C_k(g) for each constraint k.

        Returns
        -------
        Any
            Projection functionals.
        """
        if self._rho_stale:
            self._rho = sp.Matrix(
                [sp.Add(kappa, -self._C[k](self._g)) for k, kappa in enumerate(self._K)]
            )
            # Previously never cleared, so rho was recomputed on every access.
            self._rho_stale = False
        return self._rho

    @property
    def s(self) -> Exprs:
        """
        Support functions.

        Returns
        -------
        Exprs
            Support functions.
        """
        return self._s

    @s.setter
    def s(self, s: Exprs) -> None:
        """
        Set the support functions.

        Parameters
        ----------
        s : Exprs
            This is a tuple or list of the support functions. These should be
            given in terms of sympy symbols and constants. For example, if we
            wanted to use the constant function x = 1 as a support function,
            then we would use sympy.re(1) in this iterable.
        """
        self._s = s
        self._S_stale = True
        self._alpha_stale = True
        self._phi_stale = True
        self._ce_stale = True

    @property
    def kappa(self):
        """
        Kappa values.

        Returns
        -------
        Exprs
            Kappa values.
        """
        return self._K

    @kappa.setter
    def kappa(self, kappa: Exprs) -> None:
        """
        Set the kappa values.

        Parameters
        ----------
        kappa : Exprs
            This is a tuple or list of the kappa portion of each constraint.
            For the example u(3) = 0, the kappa portion is simply 0. Note, the
            standard Python 0 should not be used, but rather, sympy.re(0).
        """
        self._K = kappa
        self._rho_stale = True
        self._ce_stale = True

    @property
    def C(self) -> ConstraintOperators:
        """
        Constraint operators.

        Returns
        -------
        ConstraintOperators
            Constraint operators.
        """
        return self._C

    @C.setter
    def C(self, C: ConstraintOperators) -> None:
        """
        Set the constraint operators.

        Parameters
        ----------
        C : ConstraintOperators
            This is a tuple or list of constraint operators. Each element in
            the iterable should be a Python function that takes in a sympy
            function, such as `g(x)`, and outputs that function evaluated in
            the same way as the function in the constraint. For example, if
            the constraint was u(3) = 0, then the associated constraint
            operator would be: `lambda u: u.subs(x, 3)`.
        """
        self._C = C
        self._S_stale = True
        self._alpha_stale = True
        self._phi_stale = True
        self._rho_stale = True
        self._ce_stale = True

    @property
    def g(self) -> AppliedUndef | Any:
        """
        Free function.

        Returns
        -------
        AppliedUndef | Any
            Free function.
        """
        return self._g

    @g.setter
    def g(self, g: AppliedUndef | Any) -> None:
        """
        Set the free function.

        Parameters
        ----------
        g : AppliedUndef | Any
            This is the free function used in the constrained expression. For
            example, `g(x)`.
        """
        self._g = g
        self._ce_stale = True
        self._rho_stale = True

    def _solveCe(self) -> None:
        """
        Solves the constrained expression (g + phi * rho) and stores it in
        self._ce.
        """
        self._ce = sp.Add(self.g, (self.phi * self.rho)[0])

    @staticmethod
    def _constraintHolds(value: Any, kappa: Any) -> bool:
        """
        Decide whether a constraint evaluated on the constrained expression
        equals its kappa value.

        Parameters
        ----------
        value : Any
            Result of applying a constraint operator to the constrained
            expression.
        kappa : Any
            Expected value of that constraint.

        Returns
        -------
        bool
            True if the two simplify to the same form, or if their difference
            simplifies to zero.
        """
        if sp.simplify(value) == sp.simplify(kappa):
            return True
        # Structural comparison can miss equivalent forms; the difference
        # test is SymPy's recommended symbolic equality check.
        try:
            return sp.simplify(value - kappa) == 0
        except TypeError:
            # e.g. a matrix-valued result that cannot be subtracted from kappa.
            return False

    def checkCe(self) -> bool:
        """
        Checks the constrained expression stored in the class against the
        stored constraints.

        Returns
        -------
        bool
            True if the constrained expression satisfies the constraints and
            False otherwise. A warning is emitted for each violated constraint.
        """
        ce = self.ce
        ret = True
        for k, (c, kappa) in enumerate(zip(self._C, self._K)):
            # Evaluate each constraint once; reuse the value in the warning.
            value = c(ce)
            if not self._constraintHolds(value, kappa):
                TFCPrint.Warning(
                    f"Expected result of constraint {k+1} to be {kappa}, but got {value}."
                )
                ret = False
        return ret
class TfcPrinter(PythonCodePrinter):
    """
    Python code printer that emits ``np.`` calls instead of ``math.`` calls and
    prints derivatives as nested ``egrad(f, i)`` calls, where ``i`` is the
    position of the differentiation variable in ``f``'s argument list.

    ``egrad`` is assumed to be an elementwise-gradient helper available in the
    namespace where the printed code runs (NOTE(review): confirm).
    """

    def __init__(self, settings=None):
        # Switch math to numpy. ``_kf``/``_kc`` are class-level dicts inherited
        # from PythonCodePrinter, so build per-instance copies instead of
        # mutating them in place (which would change every sympy Python printer).
        self._kf = {k: v.replace("math", "np") for k, v in self._kf.items()}
        self._kc = {k: v.replace("math", "np") for k, v in self._kc.items()}
        super().__init__(settings=settings)

    def _hprint_Pow(self, expr: Expr, rational: bool = False, sqrt: str = "np.sqrt"):
        """
        Override _hprint_Pow to use np.sqrt rather than math.sqrt.

        Parameters
        ----------
        expr : Expr
            Expression to print.
        rational : bool
            Whether the expression is rational.
        sqrt : str
            String to print for the sqrt.

        Returns
        -------
        str
            String to print.
        """
        return super()._hprint_Pow(expr, rational=rational, sqrt=sqrt)

    def _print_Symbol(self, expr: Expr) -> str:
        """
        Print a Symbol as its plain name.

        Parameters
        ----------
        expr : Expr
            Symbol to print.

        Returns
        -------
        str
            String to print.
        """
        return self._print(str(expr))

    def _print_Function(self, expr: Expr) -> str:
        """
        Print an applied function using sympy's str form, e.g. ``g(x, y)``.

        Parameters
        ----------
        expr : Expr
            Function to print.

        Returns
        -------
        str
            String to print.
        """
        return self._print(str(expr))

    def _print_Subs(self, subs: Expr) -> str:
        """
        Print a Subs object by substituting the printed new values for the
        printed old values in the printed expression.

        Substitution is simultaneous and only replaces whole identifiers, so
        e.g. substituting ``x`` leaves ``np.exp`` intact and a replacement
        value is never substituted again.

        Parameters
        ----------
        subs : Expr
            Substitution(s).

        Returns
        -------
        str
            String to print.
        """
        import re

        expr, old, new = subs.args
        printed = self._print(expr)
        replacements = {self._print(o): self._print(n) for o, n in zip(old, new)}
        if not replacements:
            return printed

        names = sorted(replacements, key=len, reverse=True)
        # Not preceded by an identifier char or '.', not followed by an
        # identifier char: matches standalone names only.
        pattern = re.compile(
            r"(?<![\w.])(?:" + "|".join(re.escape(name) for name in names) + r")(?!\w)"
        )
        # A function replacement avoids backslash interpretation in new values.
        return pattern.sub(lambda match: replacements[match.group(0)], printed)

    def _print_Derivative(self, expr: Expr) -> str:
        """
        Print a derivative as nested ``egrad`` calls applied to the function's
        arguments, e.g. d^2 g(x, y) / dx dy -> ``egrad(egrad(g,0),1)(x, y)``.

        Parameters
        ----------
        expr : Expr
            Expression to print.

        Returns
        -------
        str
            String to print.
        """
        # function is the full function, e.g. g(x, y); each derivative entry
        # is a (symbol, order) pair, e.g. dg/dx has [(x, 1)].
        function, *derivative_vars = expr.args
        function_vars = function.args

        ret = function.func.__name__
        for var, order in derivative_vars:
            # egrad differentiates with respect to an argument position.
            ind = function_vars.index(var)
            for _ in range(order):
                ret = f"egrad({ret},{ind})"

        # Call the derivative with the function's own arguments.
        args = ", ".join(self._print(arg) for arg in function_vars)
        return f"{ret}({args})"