Basis function backends


The TFC module provides two basis function backends: one written in C++ and one written in Python. For the utfc and mtfc classes, the backend is selected with the backend keyword argument. The C++ backend is the default, and backend="Python" selects the Python backend. The pros and cons of each backend are summarized in the table below:

| Feature | C++ backend | Python backend |
|---|---|---|
| Number types supported | Doubles only | Any number type that NumPy supports: float, complex, etc. |
| Derivative order supported | Arbitrary-order derivatives for most basis functions, but only up to 8th-order derivatives for some. | Arbitrary-order derivatives for all basis functions. |
| Compiling with JIT | Can be compiled with native JAX JIT. An optimizer can treat the variables used to compute the basis functions as optimization variables. | Can only be compiled with pejit, and the basis function outputs must be cacheable before compiling. That is, the inputs to the basis functions cannot be optimized; they must be constant with respect to the optimization problem. This holds for differential equations solved via TFC, but not for all optimization problems. See the pejit tutorial for more details. |
| Compiling on the GPU | Supported via pejit only | Supported via pejit only |
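The first two rows of the table describe what a Python-level implementation gets almost for free. The NumPy-only sketch below is an illustration, not TFC's backend code. It assumes Chebyshev polynomials T_0, ..., T_deg as the basis, works directly in the native variable z on [-1, 1] (no domain mapping), and uses numpy.polynomial.chebyshev. It shows two things: the output dtype follows the input, so complex points give complex values, and any derivative order is available, including orders above 8. The function name cheb_basis_deriv is hypothetical.

```python
import numpy as np
from numpy.polynomial import chebyshev as C


def cheb_basis_deriv(z, deg, order):
    """Hypothetical helper: order-th derivative (w.r.t. z) of T_0..T_deg at z.

    Illustrative sketch only, not TFC's backend. Returns an array of shape
    (len(z), deg + 1) whose dtype follows z, so complex z gives complex output.
    """
    z = np.asarray(z)
    # Column k of the identity matrix holds the Chebyshev coefficients of T_k.
    coeffs = np.eye(deg + 1)
    # Differentiate every column `order` times; any non-negative order works,
    # and orders above deg simply yield zeros.
    dcoeffs = C.chebder(coeffs, m=order)
    # chebval returns shape (deg + 1, len(z)); transpose to one row per point.
    return C.chebval(z, dcoeffs).T


# Complex inputs: T_0, T_1, T_2 evaluated at z = 1j.
print(cheb_basis_deriv(np.array([1j]), deg=2, order=0))

# A 9th-order derivative, beyond the 8th-order limit mentioned in the table.
print(cheb_basis_deriv(np.linspace(-1.0, 1.0, 4), deg=10, order=9))
```

For example, the first call returns the complex row [1, 1j, -3] because T_2(z) = 2z^2 - 1. Similarly, the 5th derivative of T_5(z) = 16z^5 - 20z^3 + 5z is the constant 1920.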

For the vast majority of applications, the C++ backend is sufficient. In addition, the regular JAX jit transform is easier for new users than pejit, so C++ is the default backend. However, more advanced applications need the more general basis function implementation that the Python backend provides. Examples include solving complex-valued differential equations or using derivatives above 8th order for the basis functions where the C++ backend stops there.

Using the backends

Apart from how they are JIT-compiled, the two backends behave identically: they have the same API and are used in the same way.

import jax.numpy as np
from tfc import utfc
from tfc.utils import egrad

# Create two versions of the utfc class. One with a C++ backend and the other with a Python backend.
cppBackend = utfc(6, 0, 2, x0=0.0, xf=1.0)
pythonBackend = utfc(6, 0, 2, x0=0.0, xf=1.0, backend="Python")

# Get H and x
x = cppBackend.x
Hcpp = cppBackend.H
Hpython = pythonBackend.H

# Take a derivative and print the result
dHcpp = egrad(Hcpp)
dHpython = egrad(Hpython)

print("C++ result:")
print(dHcpp(x))
print("\nPython result:")
print(dHpython(x))
C++ result:
[[ 0.          2.         -8.        ]
 [ 0.          2.         -6.47213595]
 [ 0.          2.         -2.47213595]
 [ 0.          2.          2.47213595]
 [ 0.          2.          6.47213595]
 [ 0.          2.          8.        ]]

Python result:
[[ 0.          2.         -8.        ]
 [ 0.          2.         -6.47213595]
 [ 0.          2.         -2.47213595]
 [ 0.          2.          2.47213595]
 [ 0.          2.          6.47213595]
 [ 0.          2.          8.        ]]
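The printed values can be reproduced by hand, which helps clarify what H returns. The NumPy-only sketch below rests on assumptions this tutorial does not state. It assumes that utfc's default basis is Chebyshev polynomials and that its points are Chebyshev-Gauss-Lobatto points mapped linearly from [-1, 1] onto [x0, xf]. Under those assumptions, dT_1/dx = 2 and dT_2/dx = 2 * 4z = 8z. These values match the second and third columns above.

```python
import numpy as np
from numpy.polynomial import chebyshev as C

# Assumed setup mirroring utfc(6, 0, 2, x0=0.0, xf=1.0): N = 6 points and
# Chebyshev polynomials T_0, T_1, T_2 (nC = 0, so no basis terms removed).
N, deg, x0, xf = 6, 2, 0.0, 1.0

# Assumption: Chebyshev-Gauss-Lobatto points on [-1, 1], ordered from -1 to 1,
# which correspond to x = x0 + (z + 1) * (xf - x0) / 2 on [x0, xf].
z = -np.cos(np.pi * np.arange(N) / (N - 1))

# Chain rule: d/dx = (dz/dx) * d/dz with dz/dx = 2 / (xf - x0).
dcoeffs = C.chebder(np.eye(deg + 1), m=1, scl=2.0 / (xf - x0))

# One row per point, one column per basis function, like dHcpp(x) above.
dH = C.chebval(z, dcoeffs).T
print(dH)
```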

When compiling with JAX's JIT, only the C++ backend can be compiled natively. To compile a function that uses the Python backend, the basis function outputs must be cached as compile-time constants using pejit.

from jax import jit
from tfc.utils import pejit

# Define xi for use in f
xi = np.ones(Hcpp(x).shape[1])

# Define the functions to be JITed
cpp_f = lambda x,xi: np.dot(dHcpp(x),xi)
python_f = lambda x,xi: np.dot(dHpython(x),xi)

# JIT the functions
cpp_f_jit = jit(cpp_f)
python_f_jit = pejit(x, xi, constant_arg_nums=[0])(python_f)

# Print the results
print("C++ backend result:")
print(cpp_f_jit(x,xi))
print("\nPython backend result:")
print(python_f_jit(xi))
C++ backend result:
[-6.         -4.47213595 -0.47213595  4.47213595  8.47213595 10.        ]

Python backend result:
[-6.         -4.47213595 -0.47213595  4.47213595  8.47213595 10.        ]

Note that, to compile a function that uses the Python backend, the basis function outputs must not need to be computed at run time. In other words, the function being compiled must be set up so that the output of the Python backend basis functions is known and can be cached at compile time. This is the case for differential equations (see, e.g., the complex ODE tutorial) and for many other optimization problems, but not for all optimization problems.
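One way to picture the caching requirement is as partial evaluation. Everything that depends only on the constant argument is computed once, which leaves a function of the remaining arguments. The sketch below is a plain NumPy analogy, not pejit's implementation. make_cached_f is a hypothetical helper. The basis and point set (Chebyshev polynomials on six Chebyshev-Gauss-Lobatto points, x in [0, 1]) are assumptions chosen to mirror the example above.

```python
import numpy as np
from numpy.polynomial import chebyshev as C


def make_cached_f(z, deg, scale):
    """Hypothetical helper: compute the basis-derivative matrix once and
    return f(xi) = dH @ xi, which needs only xi when it is called.

    This mimics the idea behind caching basis outputs as compile-time
    constants; it is not how pejit is implemented.
    """
    dH = C.chebval(z, C.chebder(np.eye(deg + 1), m=1, scl=scale)).T
    dH.setflags(write=False)  # freeze the cached matrix

    def f(xi):
        # Only xi varies between calls; dH is never recomputed.
        return dH @ xi

    return f


# Assumed setup: 6 Chebyshev-Gauss-Lobatto points, T_0..T_2, x in [0, 1].
z = -np.cos(np.pi * np.arange(6) / 5)
f = make_cached_f(z, deg=2, scale=2.0)
print(f(np.ones(3)))
```

Because dH is built when f is created, changing z afterward does not change f. If the points themselves were optimization variables, the cached matrix would be stale. For the same reason, the Python backend under pejit only suits problems where the basis function inputs are constant.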