Basis function backends#
The TFC module comes equipped with C++ and Python versions of the basis function backends. For the utfc
and mtfc
modules, the basis function backend can be changed using the backend
keyword. The pros and cons of each backend are summarized in the table below:
| Feature | C++ backend | Python backend |
|---|---|---|
| Number types supported | Doubles only | Any number type NumPy supports: float, complex, etc. |
| Derivative order supported | Arbitrary-order derivatives for most basis functions, but only up to 8th-order derivatives for some | Arbitrary-order derivatives for all basis functions |
| Compiling with JIT | Can be compiled with native JAX JIT; the optimization function can optimize over the variables used to compute the basis functions | Can only be compiled with `pejit` |
| Compiling on the GPU | Supported via `jit` | Supported via `pejit` |
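The dtype flexibility in the table comes from the Python backend being implemented on top of NumPy. As a conceptual illustration only (this is NumPy's Chebyshev module, not the TFC API), the same idea lets a NumPy-based basis evaluate at complex points and take derivatives of any order:

```python
import numpy as np
from numpy.polynomial import chebyshev as cheb

# Chebyshev basis of degree 2 evaluated at complex points -- something a
# doubles-only backend cannot do.
z = np.array([0.5 + 0.5j, 1.0 + 0.0j])
V = cheb.chebvander(z, 2)  # columns are T_0, T_1, T_2
print(V.dtype)  # complex128

# Arbitrary-order derivatives via coefficient differentiation:
c = np.eye(3)  # each column selects one basis function T_0, T_1, T_2
dc = cheb.chebder(c, m=1, axis=0)  # first derivative of each basis function
print(cheb.chebval(np.array([0.0, 1.0]), dc))  # rows: dT_0, dT_1, dT_2
```

Here `dT_2/dx = 4x`, so the last row evaluates to `0` at `x = 0` and `4` at `x = 1`.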
For the vast majority of applications, the C++ backend is sufficient. In addition, it is easier for a new user to use the regular JAX `jit` transform than `pejit`, so C++ is the default backend. However, more advanced applications, e.g., solving complex differential equations, need a more robust version of the basis functions, which is why the Python backend exists.
Using the backends#
Other than how they are JIT compiled, the two backends behave the same: they share the same API and are used in the same way.
[1]:
import jax.numpy as np
from tfc import utfc
from tfc.utils import egrad
# Create two versions of the utfc class. One with a C++ backend and the other with a Python backend.
cppBackend = utfc(6, 0, 2, x0=0.0, xf=1.0)
pythonBackend = utfc(6, 0, 2, x0=0.0, xf=1.0, backend="Python")
# Get H and x
x = cppBackend.x
Hcpp = cppBackend.H
Hpython = pythonBackend.H
# Take a derivative and print the result
dHcpp = egrad(Hcpp)
dHpython = egrad(Hpython)
print("C++ result:")
print(dHcpp(x))
print("\nPython result:")
print(dHpython(x))
C++ result:
[[ 0. 2. -8. ]
[ 0. 2. -6.47213595]
[ 0. 2. -2.47213595]
[ 0. 2. 2.47213595]
[ 0. 2. 6.47213595]
[ 0. 2. 8. ]]
Python result:
[[ 0. 2. -8. ]
[ 0. 2. -6.47213595]
[ 0. 2. -2.47213595]
[ 0. 2. 2.47213595]
[ 0. 2. 6.47213595]
[ 0. 2. 8. ]]
When compiling with JAX's JIT, only the C++ backend can be compiled natively. To compile the Python backend, its outputs must be cached as compile-time constants using `pejit`.
[2]:
from jax import jit
from tfc.utils import pejit
# Define xi for use in f
xi = np.ones(Hcpp(x).shape[1])
# Define the functions to be JITed
cpp_f = lambda x,xi: np.dot(dHcpp(x),xi)
python_f = lambda x,xi: np.dot(dHpython(x),xi)
# JIT the functions
cpp_f_jit = jit(cpp_f)
python_f_jit = pejit(x, xi, constant_arg_nums=[0])(python_f)
# Print the results
print("C++ backend result:")
print(cpp_f_jit(x,xi))
print("\nPython backend result:")
print(python_f_jit(xi))
C++ backend result:
[-6. -4.47213595 -0.47213595 4.47213595 8.47213595 10. ]
Python backend result:
[-6. -4.47213595 -0.47213595 4.47213595 8.47213595 10. ]
Notice that this means a function using the Python backend can only be compiled when the basis function output does not need to be computed at run time; i.e., the function the user wants to compile must be set up so that the result of the Python backend's basis functions is known, and can be cached, at compile time. This is the case for differential equations, e.g., see the complex ODE tutorial, and for many other optimization problems, but it is not true of all optimization problems.
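The caching idea behind `pejit` can be mimicked with a plain Python closure. This is a conceptual sketch only, not TFC's implementation: the basis-derivative matrix is evaluated once and frozen inside the returned function, which then depends only on the remaining argument `xi`, just as `constant_arg_nums=[0]` freezes `x` above.

```python
import numpy as np

def make_cached_f(dH_at_x):
    """Return a function of xi only; dH(x) is baked in as a constant,
    mirroring what pejit does via constant_arg_nums."""
    const = np.asarray(dH_at_x)  # frozen at "compile" time
    def f(xi):
        return const @ xi
    return f

# Hypothetical basis-derivative matrix standing in for dHpython(x)
dH = np.array([[0.0, 2.0, -8.0],
               [0.0, 2.0,  8.0]])
f_cached = make_cached_f(dH)
print(f_cached(np.ones(3)))  # [-6. 10.]
```

If the matrix cannot be computed ahead of time, e.g., because `x` is itself an optimization variable, this caching strategy, and hence `pejit`, does not apply.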