PEJIT (Partial Evaluation Just In Time)#
This notebook explains how to use the pejit function. pejit works very similarly to JAX’s jit function, but it can also cache variables, and inner-functions that depend only on those variables, as compile-time constants. This can improve performance if the inner-functions are computationally intensive. It also means that inner-functions implemented as JAX primitives do not need JIT transforms in order to be used in JITed code: their return values are cached, so the inner-functions themselves are never exposed to the JIT.
pejit function#
Positional input arguments#
The positional arguments to pejit should be the same type/size/shape as the positional arguments to the function. pejit uses these internally when tracing.
Note that calling pejit returns a wrapper that is then applied to a function. The function to be compiled is therefore passed as the argument to the result of pejit, e.g., pejit(x, xi, constant_arg_nums=[0])(f); alternatively, pejit can be used as a decorator. See the examples below for more details.
Optional input keyword arguments#
pejit takes the same optional keyword arguments as JAX’s jit. See the JAX documentation for more details. In addition, pejit takes the constant_arg_nums keyword, which is a list of integers that defines which positional arguments of the function being JITed will be treated as compile-time constants. Inner-functions that depend only on these constants will also be treated as compile-time constants.
Outputs#
The output of this function is a JITed function from which the positional arguments specified by constant_arg_nums have been removed.
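The index bookkeeping described above can be sketched in plain Python. The helper bind_constants below is hypothetical and written only for illustration: it is not part of TFC, and it performs no JIT compilation or partial evaluation. It only shows how fixing the positions listed in constant_arg_nums leaves a function of the remaining arguments, in their original order.

```python
def bind_constants(fn, example_args, constant_arg_nums):
    """Hypothetical helper: fix the arguments at constant_arg_nums and
    return a function that takes only the remaining arguments."""
    constants = {i: example_args[i] for i in constant_arg_nums}
    n_args = len(example_args)

    def reduced(*runtime_args):
        runtime_iter = iter(runtime_args)
        # Rebuild the full argument list: stored constants go back into
        # their original slots, runtime arguments fill the other slots.
        full_args = [
            constants[i] if i in constants else next(runtime_iter)
            for i in range(n_args)
        ]
        return fn(*full_args)

    return reduced


def g(a, b, c):
    return a * 100 + b * 10 + c


# Positions 0 and 2 (a and c) are fixed; only b remains at call time.
g_reduced = bind_constants(g, (1, 2, 3), constant_arg_nums=[0, 2])
value = g_reduced(7)
print(value)
```

Here g_reduced takes a single argument because two of the three positions were fixed, which mirrors how f_pejit(xi) in the examples below takes only xi.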
Traces of jit vs. pejit#
As an example, suppose we have a function like
\[f(x, \xi) = H(x)\,\xi\]
and we want to treat \(x\) as a compile-time constant, i.e., if we let \(y = H(x)\) then
\[f(\xi) = y\,\xi.\]
To begin, let’s utilize some of the TFC basis functions for \(H(x)\).
[1]:
import jax.numpy as np
from tfc import utfc
# Define H(x) using Chebyshev orthogonal polynomials
tfc = utfc(3, 0, 2, x0=0.0, xf=1.0)
H = tfc.H
x = tfc.x
# Define an example xi
xi = np.ones(H(x).shape[1])
# Define f(x,xi)
f = lambda x,xi: np.dot(H(x),xi)
Under the hood, pejit uses pe to partially evaluate the function while treating certain arguments as constants. Below, the JAX traces of the original function, and the function after pe is applied are printed out so you can see the difference.
[2]:
from jax import make_jaxpr
from tfc.utils.TFCUtils import pe
jaxpr_f = make_jaxpr(f)
f_pe = pe(x,xi,constant_arg_nums=[0])(f)
jaxpr_f_pe = make_jaxpr(f_pe)
print("Original function")
print(jaxpr_f(x,xi))
print("\nFunction after pe")
print(jaxpr_f_pe(xi))
Original function
{ lambda ; a:f64[3] b:f64[3]. let
c:f64[3,3] = H[d=0 full=False] a
d:f64[3] = dot_general[
dimension_numbers=(((1,), (0,)), ((), ()))
precision=None
preferred_element_type=None
] c b
in (d,) }
Function after pe
{ lambda a:f64[3,3]; b:f64[3]. let
c:f64[3] = dot_general[
dimension_numbers=(((1,), (0,)), ((), ()))
precision=None
preferred_element_type=None
] a b
in (c,) }
The original function takes in two arguments, x and xi (called a and b in the trace). It has no known constants. It uses x to compute H(x) and stores the result in c. Finally, it takes the dot product between H(x) and xi and stores the result in d, which it returns.
The pe function takes in one argument xi (b in the trace). It has one stored constant H(x) (called a in the trace). It takes the dot product between H(x) and xi and stores the result in c, which it returns. Notice that this trace has pre-computed H(x) and stored the result in a, so x is not needed at runtime and H(x) is not run at run-time.
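The same difference between the two traces can be reproduced with plain JAX, without TFC. The sketch below is only an analogue of what pe does, not its implementation: H_demo is a hypothetical stand-in for H(x), and the constant is created by evaluating it by hand and closing over the result.

```python
import jax
import jax.numpy as jnp


def H_demo(x):
    # Hypothetical stand-in for H(x): columns 1, x, x^2.
    return jnp.stack((jnp.ones_like(x), x, x**2), axis=1)


def f(x, xi):
    return jnp.dot(H_demo(x), xi)


x = jnp.array([0.0, 0.5, 1.0])
xi = jnp.ones(3)

# Evaluate H_demo(x) once, outside of any trace.
y = H_demo(x)


def f_fixed(xi):
    # y is closed over, so it is a constant from the trace's point of view.
    return jnp.dot(y, xi)


jaxpr_full = jax.make_jaxpr(f)(x, xi)
jaxpr_fixed = jax.make_jaxpr(f_fixed)(xi)

# Count the runtime inputs of each trace.
n_inputs_full = len(jaxpr_full.jaxpr.invars)
n_inputs_fixed = len(jaxpr_fixed.jaxpr.invars)

print(jaxpr_full)
print(jaxpr_fixed)
```

The first trace has two inputs and contains the operations that build H_demo(x); the second has one input and only the dot product. The exact text of the printed traces depends on the installed JAX version.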
Let’s take a look at the compiled results of these functions. For the pe function, we will simply use pejit on f, which runs pe before jit. Alternatively, we could have taken the function f_pe from above and passed it through the regular jit to get the same result.
[3]:
from jax import jit
from tfc.utils import pejit
# JIT f(x,xi) using the regular jax.jit and print out the compiled code
print("JAX jit")
f_jit = jit(f)
f_jit_lowered = f_jit.lower(x,xi)
print(f_jit_lowered.compile().compiler_ir()[0].to_string())
# PEJIT f(x,xi) using tfc.utils.pejit and print out the compiled code
print("TFC pejit")
f_pejit = pejit(x, xi, constant_arg_nums=[0])(f)
f_pejit_lowered = f_pejit.lower(xi)
print(f_pejit_lowered.compile().compiler_ir()[0].to_string())
JAX jit
HloModule jit__lambda_.12, entry_computation_layout={(f64[3]{0},f64[3]{0})->f64[3]{0}}
ENTRY %main.14 (Arg_0.1: f64[3], Arg_1.2: f64[3]) -> f64[3] {
%constant.3 = s64[] constant(0)
%Arg_0.1 = f64[3]{0} parameter(0)
%constant.4 = pred[] constant(false)
%constant.8 = s64[] constant(3)
%custom-call.1 = f64[3,3]{1,0} custom-call(s64[] %constant.3, f64[3]{0} %Arg_0.1, s64[] %constant.3, pred[] %constant.4, s64[] %constant.8, /*index=5*/s64[] %constant.8), custom_call_target="BasisFunc0", metadata={op_name="custom-call.7"}
%Arg_1.2 = f64[3]{0} parameter(1)
ROOT %dot.13 = f64[3]{0} dot(f64[3,3]{1,0} %custom-call.1, f64[3]{0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_34918/1347355215.py" source_line=13}
}
TFC pejit
HloModule jit__lambda_.13, entry_computation_layout={(f64[3]{0})->f64[3]{0}}
ENTRY %main.4 (Arg_0.1: f64[3]) -> f64[3] {
%constant.2 = f64[3,3]{1,0} constant({ { 1, -1, 1 }, { 1, 0, -1 }, { 1, 1, 1 } })
%Arg_0.1 = f64[3]{0} parameter(0)
ROOT %dot.3 = f64[3]{0} dot(f64[3,3]{1,0} %constant.2, f64[3]{0} %Arg_0.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_34918/1347355215.py" source_line=13}
}
The compiled outputs are a bit harder to parse than the traces above. However, we can still see the same features:
The TFC pejit result is shorter than the JAX jit result. This is because the TFC pejit result is not computing H(x).
The TFC pejit result has a stored constant (called constant.2) which is of size 3x3. This is the result of H(x).
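For readers who want to inspect lowered code themselves, the following plain-JAX sketch uses the as_text method of the lowered and compiled objects. It does not use pejit: the 3x3 matrix printed as constant.2 above is simply typed in by hand and closed over, which is an assumption made to keep the example independent of TFC.

```python
import jax
import jax.numpy as jnp

# The 3x3 matrix that appears as constant.2 in the output above,
# entered by hand for this sketch.
y = jnp.array([[1.0, -1.0, 1.0],
               [1.0, 0.0, -1.0],
               [1.0, 1.0, 1.0]])


def f_fixed(xi):
    # y is closed over, so only xi is a runtime input.
    return jnp.dot(y, xi)


xi = jnp.ones(3)

lowered = jax.jit(f_fixed).lower(xi)
ir_text = lowered.as_text()              # IR before XLA optimization
hlo_text = lowered.compile().as_text()   # code after XLA compilation

print(hlo_text)
result = f_fixed(xi)
```

The printed text differs between JAX versions and backends, so it will not match the listing above line for line.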
No need for JIT transforms#
As mentioned above, the inner-functions whose values are cached by pejit do not need a jit transform in order to be run through pejit. This means that we can still “compile” results that utilize these inner-functions. Moreover, these inner-functions can still utilize other JAX transforms. To illustrate, below is a simple function H(x) whose gradient transforms have been implemented, but whose JIT transform has not been defined. f(x,xi) uses the derivative of H(x) in its calculation.
[4]:
import numpy as onp
from jax import core
from jax.extend.core import Primitive
from jax.interpreters import ad, batching
from tfc.utils import egrad
# Define a simple function with gradient transformations.
H_p = Primitive("H")

def H(x, d=0):
    return H_p.bind(x, d=d)

# Concrete implementation (runs in plain NumPy)
def H_impl(x, d=0):
    if d == 0:
        return onp.vstack((x, x**2, x**3, x**4)).T
    elif d == 1:
        return onp.vstack((onp.ones_like(x), 2*x, 3*x**2, 4*x**3)).T
    else:
        raise ValueError("Derivatives beyond order 1 have not been implemented yet.")

H_p.def_impl(H_impl)

# Abstract evaluation: the output has one column per basis function
def H_abstract_eval(x, d=0):
    dim1 = 4
    if len(x.shape) == 0:
        dims = (dim1,)
    else:
        dims = (x.shape[0], dim1)
    return core.ShapedArray(dims, x.dtype)

H_p.def_abstract_eval(H_abstract_eval)

# Define batching translation
def H_batch(vec, batch, d=0):
    return H(*vec, d=d), batch[0]

# Define Jacobian-vector product
def H_jvp(arg_vals, arg_tans, d=0, full=False):
    x = arg_vals[0]
    dx = arg_tans[0]
    if not (dx is ad.Zero):
        if type(dx) is batching.BatchTracer:
            flag = onp.any(dx.val != 0)
        else:
            flag = onp.any(dx != 0)
        if flag:
            if len(dx.shape) == 1:
                out_tans = H(x, d=d + 1) * onp.expand_dims(dx, 1)
            else:
                out_tans = H(x, d=d + 1) * dx
    else:
        dim0 = x.shape[0]
        dim1 = 4  # number of basis functions returned by H
        out_tans = np.zeros((dim0, dim1), dtype=x.dtype)
    return (H(x, d=d), out_tans)

ad.primitive_jvps[H_p] = H_jvp

# Define f(x,xi) that uses H(x)
x = np.array([1.,2.,3.])
xi = np.array([1.,2.,3.,4.])
dH = egrad(H)

def f(x,xi):
    return x + np.dot(dH(x),xi)

f(x,xi)
[4]:
DeviceArray([ 31., 175., 529.], dtype=float64)
If we try to call the regular jit on f(x,xi), we get an error because the JIT transform for H(x) is undefined.
[5]:
f_jit = jit(f)
f_jit(x,xi)
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
/tmp/ipykernel_34918/1304517386.py in <module>
1 f_jit = jit(f)
----> 2 f_jit(x,xi)
/tmp/ipykernel_34918/2530243949.py in f(***failed resolving arguments***)
65 def f(x,xi):
---> 66 return x + np.dot(dH(x),xi)
67
~/.local/lib/python3.10/site-packages/tfc/utils/TFCUtils.py in wrapped(***failed resolving arguments***)
168 tans = tuple(
--> 169 [onesRobust(args[i]) if i == j else zerosRobust(args[i]) for i in range(len(args))]
170 )
/tmp/ipykernel_34918/2530243949.py in H(***failed resolving arguments***)
9 def H(x, d=0):
---> 10 return H_p.bind(x, d=d)
11
/tmp/ipykernel_34918/2530243949.py in H_jvp(***failed resolving arguments***)
43 if type(dx) is batching.BatchTracer:
---> 44 flag = onp.any(dx.val != 0)
45 else:
/tmp/ipykernel_34918/2530243949.py in H(***failed resolving arguments***)
9 def H(x, d=0):
---> 10 return H_p.bind(x, d=d)
11
JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'H' not found for platform cpu
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
NotImplementedError Traceback (most recent call last)
/tmp/ipykernel_34918/1304517386.py in <module>
1 f_jit = jit(f)
----> 2 f_jit(x,xi)
[... skipping hidden 12 frame]
~/.local/lib/python3.10/site-packages/jax/interpreters/mlir.py in jaxpr_subcomp(ctx, jaxpr, tokens, consts, *args)
988 rule = xla_fallback_lowering(eqn.primitive)
989 else:
--> 990 raise NotImplementedError(
991 f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
992 f"found for platform {ctx.platform}")
NotImplementedError: MLIR translation rule for primitive 'H' not found for platform cpu
However, if we use pejit and set x as a compile time constant, the value of H(x) will be pre-computed, and its result will be cached and used by jit. Therefore, the transform for H(x) is never used nor needed.
[6]:
f_pejit = pejit(x,xi,constant_arg_nums=[0])(f)
f_pejit(xi)
[6]:
DeviceArray([ 31., 175., 529.], dtype=float64)
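The same effect can be seen in plain JAX with a function that JAX cannot trace at all. The sketch below is an analogue, not pejit itself: H_numpy is a hypothetical NumPy-only stand-in for an inner-function without a JIT transform, and the pre-computation that pejit automates is done by hand.

```python
import numpy as onp
import jax
import jax.numpy as jnp


def H_numpy(x):
    # Hypothetical stand-in: converting to a NumPy array fails on a JAX
    # tracer, so this function cannot be traced by jit.
    x = onp.asarray(x)
    return onp.vstack((x, x**2)).T


def f(x, xi):
    return jnp.dot(H_numpy(x), xi)


x = onp.array([1.0, 2.0, 3.0])
xi = jnp.array([1.0, 2.0])

# JITing f directly fails, because x is a tracer inside the trace.
try:
    jax.jit(f)(x, xi)
    jit_failed = False
except Exception:
    jit_failed = True

# Pre-compute H_numpy(x) and JIT only the part that depends on xi.
y = H_numpy(x)
f_fixed = jax.jit(lambda xi: jnp.dot(y, xi))
out = f_fixed(xi)
print(jit_failed, out)
```

Because H_numpy runs before tracing starts, the JIT only ever sees its result as a constant array.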
Ways to use pejit#
pejit can be used in the ways defined above or as a decorator.
[7]:
f_pejit_1 = pejit(x,xi,constant_arg_nums=[0])(f)
@pejit(x,xi,constant_arg_nums=[0])
def f_pejit_2(x,xi):
    return x + np.dot(dH(x),xi)
print("Result of f_pejit_1")
print(f_pejit_1(xi))
print("\nResult of f_pejit_2")
print(f_pejit_2(xi))
Result of f_pejit_1
[ 31. 175. 529.]
Result of f_pejit_2
[ 31. 175. 529.]
Reminders#
pejit removes the arguments defined in constant_arg_nums from the compiled function. Therefore, if x of f(x,xi) is included in constant_arg_nums, then f_pejit will only be a function of xi, i.e., f_pejit(xi).
pejit calculates the values associated with the constant_arg_nums when it is called. Therefore, if x is changed in the Python code after pejit is called, the values of H(x) cached by pejit will not be recomputed.
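The second reminder can be illustrated with a plain-JAX closure. This is an analogue under the assumption that a value cached by pejit behaves like the hand-made constant y below; it does not call pejit.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])

# Computed once from the current value of x, then closed over.
y = x**2
f_fixed = jax.jit(lambda xi: y * xi)

xi = jnp.ones(2)
before = f_fixed(xi)

# Rebinding x afterwards does not touch the stored constant y.
x = jnp.array([10.0, 20.0])
after = f_fixed(xi)

print(before, after)
```

Both calls return the result built from the original x. To pick up a new x, the constant has to be recomputed and the function rebuilt, which for pejit means calling pejit again.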