PEJIT (Partial Evaluation Just In Time)#
This notebook explains how to use the pejit function. It works very similarly to JAX’s jit function, but pejit can cache variables, and inner-functions that depend only on those variables, as compile-time constants. This can improve performance if the inner-functions are computationally intensive. It also means that inner-functions implemented as JAX primitives do not need JIT transforms in order to be JITed: their return values are cached, so the inner-functions themselves are never exposed to the JIT.
pejit function#
Positional input arguments#
The positional arguments to pejit should have the same type, size, and shape as the positional arguments to the function being JITed. pejit uses these internally when tracing.
Note that pejit returns a transformation that is then applied to a function, so the function to be JITed should be passed as the argument to the result of pejit, i.e., pejit(*args)(f). Alternatively, pejit can be used as a decorator. See the examples below for more details.
Optional input keyword arguments#
pejit takes the same optional keyword arguments as JAX’s jit. See the JAX documentation for jit for more details. In addition, pejit takes the constant_arg_nums keyword, a list of integers specifying which positional arguments of the function being JITed are treated as compile-time constants. Inner-functions that depend only on these constants are also treated as compile-time constants.
Outputs#
The output of this function is a JITed function from which the positional arguments specified by constant_arg_nums have been removed.
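To make the effect of constant_arg_nums on the call signature concrete, here is a minimal pure-Python sketch. The helper drop_constant_args is hypothetical and is not part of TFC or JAX. It only mimics how the listed positional arguments are fixed up front and disappear from the returned function; it does no tracing, no caching of inner-functions, and no compilation.

```python
def drop_constant_args(func, args, constant_arg_nums):
    """Hypothetical helper (not part of TFC): fix the positional arguments
    listed in constant_arg_nums and return a function of the remaining ones."""
    constants = {i: args[i] for i in constant_arg_nums}

    def reduced(*runtime_args):
        runtime = iter(runtime_args)
        # Rebuild the full argument list: constants come from the stored
        # values, everything else from the arguments given at call time.
        full_args = [
            constants[i] if i in constants else next(runtime)
            for i in range(len(args))
        ]
        return func(*full_args)

    return reduced


def f(x, xi):
    return x * 10 + xi


# Treat argument 0 (x) as constant; the result is a function of xi only.
f_reduced = drop_constant_args(f, (2, 0), constant_arg_nums=[0])
result = f_reduced(5)  # same as f(2, 5)
```

As with pejit, the example arguments are supplied once when the reduced function is built, and the reduced function is then called with the remaining arguments only.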
Traces of jit vs. pejit#
As an example, suppose we have a function like
\[f(x, \xi) = H(x)\,\xi\]
and we want to treat \(x\) as a compile-time constant, i.e., if we let \(y = H(x)\) then
\[f(\xi) = y\,\xi.\]
To begin, let’s utilize some of the TFC basis functions for \(H(x)\).
[1]:
import jax.numpy as np
from tfc import utfc
# Define H(x) using Chebyshev orthogonal polynomials
tfc = utfc(3, 0, 2, x0=0.0, xf=1.0)
H = tfc.H
x = tfc.x
# Define an example xi
xi = np.ones(H(x).shape[1])
# Define f(x,xi)
f = lambda x,xi: np.dot(H(x),xi)
Under the hood, pejit uses pe to partially evaluate the function while treating certain arguments as constants. Below, the JAX traces of the original function and of the function after pe is applied are printed so you can see the difference.
[2]:
from jax import make_jaxpr
from tfc.utils.TFCUtils import pe
jaxpr_f = make_jaxpr(f)
f_pe = pe(x,xi,constant_arg_nums=[0])(f)
jaxpr_f_pe = make_jaxpr(f_pe)
print("Original function")
print(jaxpr_f(x,xi))
print("\nFunction after pe")
print(jaxpr_f_pe(xi))
Original function
{ lambda ; a:f64[3] b:f64[3]. let
c:f64[3,3] = H[d=0 full=False] a
d:f64[3] = dot_general[
dimension_numbers=(((1,), (0,)), ((), ()))
precision=None
preferred_element_type=None
] c b
in (d,) }
Function after pe
{ lambda a:f64[3,3]; b:f64[3]. let
c:f64[3] = dot_general[
dimension_numbers=(((1,), (0,)), ((), ()))
precision=None
preferred_element_type=None
] a b
in (c,) }
The original function takes in two arguments, x
and xi
(called a
and b
in the trace). It has no known constants. It uses x
to compute H(x)
and stores the result in c
. Finally, it takes the dot product between H(x)
and xi
and stores the result in d
, which it returns.
The function after pe takes in one argument, xi (b in the trace). It has one stored constant, H(x) (called a in the trace). It takes the dot product between H(x) and xi and stores the result in c, which it returns. Notice that this trace has pre-computed H(x) and stored the result in a, so x is not needed and H(x) is not evaluated at run time.
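The same difference between the two traces can be reproduced by hand with plain JAX, without TFC. The sketch below is an illustration of the idea, not how pe is implemented: H_stand_in is a hypothetical replacement for H(x), and the "partial evaluation" is done manually by computing y = H_stand_in(x) once and closing over it.

```python
import jax
import jax.numpy as jnp


def H_stand_in(x):
    # Hypothetical stand-in for H(x): columns 1, x, x**2.
    return jnp.stack([jnp.ones_like(x), x, x**2], axis=1)


def f(x, xi):
    return jnp.dot(H_stand_in(x), xi)


x = jnp.array([0.0, 0.5, 1.0])
xi = jnp.ones(3)

# Evaluate the x-dependent part once, outside of any trace.
y = H_stand_in(x)


def f_closed(xi):
    # y is captured from the enclosing scope, so the trace sees it as a
    # constant rather than as an input.
    return jnp.dot(y, xi)


jaxpr_full = jax.make_jaxpr(f)(x, xi)
jaxpr_closed = jax.make_jaxpr(f_closed)(xi)

# Number of runtime inputs of each trace.
n_inputs_full = len(jaxpr_full.jaxpr.invars)
n_inputs_closed = len(jaxpr_closed.jaxpr.invars)
print(n_inputs_full, n_inputs_closed)
```

The closed-over version has a single runtime input, mirroring the trace after pe, where only xi remains an argument.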
Let’s take a look at the compiled results of these functions. For the partially evaluated function, we will simply use pejit on f, which runs pe before jit. Alternatively, we could have taken the function above and passed it through the regular jit and gotten the same result.
[3]:
from jax import jit
from tfc.utils import pejit
# JIT f(x,xi) using the regular jax.jit and print out the compiled code
print("JAX jit")
f_jit = jit(f)
f_jit_lowered = f_jit.lower(x,xi)
print(f_jit_lowered.compile().compiler_ir()[0].to_string())
# PEJIT f(x,xi) using tfc.utils.pejit and print out the compiled code
print("TFC pejit")
f_pejit = pejit(x, xi, constant_arg_nums=[0])(f)
f_pejit_lowered = f_pejit.lower(xi)
print(f_pejit_lowered.compile().compiler_ir()[0].to_string())
JAX jit
HloModule jit__lambda_.12, entry_computation_layout={(f64[3]{0},f64[3]{0})->f64[3]{0}}
ENTRY %main.14 (Arg_0.1: f64[3], Arg_1.2: f64[3]) -> f64[3] {
%constant.3 = s64[] constant(0)
%Arg_0.1 = f64[3]{0} parameter(0)
%constant.4 = pred[] constant(false)
%constant.8 = s64[] constant(3)
%custom-call.1 = f64[3,3]{1,0} custom-call(s64[] %constant.3, f64[3]{0} %Arg_0.1, s64[] %constant.3, pred[] %constant.4, s64[] %constant.8, /*index=5*/s64[] %constant.8), custom_call_target="BasisFunc0", metadata={op_name="custom-call.7"}
%Arg_1.2 = f64[3]{0} parameter(1)
ROOT %dot.13 = f64[3]{0} dot(f64[3,3]{1,0} %custom-call.1, f64[3]{0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_34918/1347355215.py" source_line=13}
}
TFC pejit
HloModule jit__lambda_.13, entry_computation_layout={(f64[3]{0})->f64[3]{0}}
ENTRY %main.4 (Arg_0.1: f64[3]) -> f64[3] {
%constant.2 = f64[3,3]{1,0} constant({ { 1, -1, 1 }, { 1, 0, -1 }, { 1, 1, 1 } })
%Arg_0.1 = f64[3]{0} parameter(0)
ROOT %dot.3 = f64[3]{0} dot(f64[3,3]{1,0} %constant.2, f64[3]{0} %Arg_0.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_34918/1347355215.py" source_line=13}
}
The compiled outputs are a bit harder to parse than the traces above. However, we can still see the same features:
* The TFC pejit result is shorter than the JAX jit result. This is because the TFC pejit result does not compute H(x).
* The TFC pejit result has a stored constant (called constant.2) of size 3x3. This is the result of H(x).
No need for JIT transforms#
As mentioned above, the inner-functions whose values are cached by pejit
do not need a jit
transform in order to be run through pejit
. This means that we can still “compile” results that utilize these inner-functions. Moreover, these inner-functions can still utilize other JAX transforms. To illustrate, below is a simple function H(x) whose gradient transforms have been implemented, but whose JIT transform has not been defined. f(x,xi)
uses the derivative of H(x)
in its
calculation.
[4]:
import numpy as onp
from jax import core
from jax.interpreters import ad, batching
from tfc.utils import egrad

# Define a simple primitive with gradient transformations but no JIT (lowering) rule.
H_p = core.Primitive("H")

def H(x, d=0):
    return H_p.bind(x, d=d)

# Implementation rule: evaluates the primitive with NumPy
def H_impl(x, d=0):
    if d == 0:
        return onp.vstack((x, x**2, x**3, x**4)).T
    elif d == 1:
        return onp.vstack((onp.ones_like(x), 2*x, 3*x**2, 4*x**3)).T
    else:
        raise ValueError("Derivatives beyond order 1 have not been implemented yet.")

H_p.def_impl(H_impl)

# Abstract evaluation: output shape and dtype only
def H_abstract_eval(x, d=0):
    dim1 = 4
    if len(x.shape) == 0:
        dims = (dim1,)
    else:
        dims = (x.shape[0], dim1)
    return core.ShapedArray(dims, x.dtype)

H_p.def_abstract_eval(H_abstract_eval)

# Define batching translation
def H_batch(vec, batch, d=0):
    return H(*vec, d=d), batch[0]

batching.primitive_batchers[H_p] = H_batch

# Define Jacobian vector product
def H_jvp(arg_vals, arg_tans, d=0):
    x = arg_vals[0]
    dx = arg_tans[0]
    # Start from a zero tangent and overwrite it only if dx is nonzero
    out_tans = np.zeros((x.shape[0], 4), dtype=x.dtype)
    if type(dx) is not ad.Zero:
        if type(dx) is batching.BatchTracer:
            flag = onp.any(dx.val != 0)
        else:
            flag = onp.any(dx != 0)
        if flag:
            if len(dx.shape) == 1:
                out_tans = H(x, d=d + 1) * onp.expand_dims(dx, 1)
            else:
                out_tans = H(x, d=d + 1) * dx
    return (H(x, d=d), out_tans)

ad.primitive_jvps[H_p] = H_jvp

# Define f(x,xi) that uses the derivative of H(x)
x = np.array([1.,2.,3.])
xi = np.array([1.,2.,3.,4.])
dH = egrad(H)

def f(x,xi):
    return x + np.dot(dH(x),xi)

f(x,xi)
[4]:
DeviceArray([ 31., 175., 529.], dtype=float64)
If we try to call the regular jit on f(x,xi), we will get an error because the JIT transform for H(x) is undefined.
[5]:
f_jit = jit(f)
f_jit(x,xi)
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
/usr/lib/python3.10/runpy.py in _run_module_as_main(***failed resolving arguments***)
190 except _Error as exc:
--> 191 msg = "%s: %s" % (sys.executable, exc)
192 sys.exit(msg)
/usr/lib/python3.10/runpy.py in _run_code(***failed resolving arguments***)
74 loader = mod_spec.loader
---> 75 fname = mod_spec.origin
76 cached = mod_spec.cached
/usr/lib/python3/dist-packages/ipykernel_launcher.py in <module>
11 # This is added back by InteractiveShellApp.init_path()
---> 12 if sys.path[0] == '':
13 del sys.path[0]
/usr/lib/python3/dist-packages/traitlets/config/application.py in launch_instance(***failed resolving arguments***)
843 """
--> 844 app = cls.instance(**kwargs)
845 app.initialize(argv)
/usr/lib/python3/dist-packages/ipykernel/kernelapp.py in start(***failed resolving arguments***)
667 if self.trio_loop:
--> 668 from ipykernel.trio_runner import TrioRunner
669 tr = TrioRunner()
/usr/lib/python3/dist-packages/tornado/platform/asyncio.py in start(***failed resolving arguments***)
194 except (RuntimeError, AssertionError):
--> 195 old_loop = None # type: ignore
196 try:
/usr/lib/python3.10/asyncio/base_events.py in run_forever(***failed resolving arguments***)
593
--> 594 old_agen_hooks = sys.get_asyncgen_hooks()
595 sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
/usr/lib/python3.10/asyncio/base_events.py in _run_once(***failed resolving arguments***)
1859
-> 1860 event_list = self._selector.select(timeout)
1861 self._process_events(event_list)
/usr/lib/python3.10/asyncio/events.py in _run(***failed resolving arguments***)
79 try:
---> 80 self._context.run(self._callback, *self._args)
81 except (SystemExit, KeyboardInterrupt):
/usr/lib/python3/dist-packages/ipykernel/kernelbase.py in dispatch_queue(***failed resolving arguments***)
460 try:
--> 461 await self.process_one()
462 except Exception:
/usr/lib/python3/dist-packages/ipykernel/kernelbase.py in process_one(***failed resolving arguments***)
446 try:
--> 447 t, dispatch, args = self.msg_queue.get_nowait()
448 except (asyncio.QueueEmpty, QueueEmpty):
/usr/lib/python3/dist-packages/ipykernel/kernelbase.py in dispatch_shell(***failed resolving arguments***)
333 self.shell_stream.flush(zmq.POLLOUT)
--> 334 return
335
/usr/lib/python3/dist-packages/ipykernel/kernelbase.py in execute_request(***failed resolving arguments***)
633 self.log.error("Got bad msg: ")
--> 634 self.log.error("%s", parent)
635 return
/usr/lib/python3/dist-packages/ipykernel/ipkernel.py in do_execute(***failed resolving arguments***)
321
--> 322 if (
323 _asyncio_runner
/usr/lib/python3/dist-packages/ipykernel/zmqshell.py in run_cell(***failed resolving arguments***)
531 self._last_traceback = None
--> 532 return super().run_cell(*args, **kwargs)
533
/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py in run_cell(***failed resolving arguments***)
2913 try:
-> 2914 result = self._run_cell(
2915 raw_cell, store_history, silent, shell_futures)
/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py in _run_cell(***failed resolving arguments***)
2935 coro = self.run_cell_async(
-> 2936 raw_cell,
2937 store_history=store_history,
/usr/lib/python3/dist-packages/IPython/core/async_helpers.py in _pseudo_sync_runner(***failed resolving arguments***)
77 try:
---> 78 coro.send(None)
79 except StopIteration as exc:
/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py in run_cell_async(***failed resolving arguments***)
1 # -*- coding: utf-8 -*-
2 """Main IPython class."""
3
/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py in run_ast_nodes(***failed resolving arguments***)
3342 "please try to upgrade IPython and open a bug report with your case.")
-> 3343 if _async:
3344 # If interactivity is async the semantics of run_code are
/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py in run_code(***failed resolving arguments***)
3450 if async_ and sys.version_info < (3,8):
-> 3451 last_expr = (await self._async_exec(code_obj, self.user_ns))
3452 code = compile('last_expr', 'fake', "single")
/tmp/ipykernel_34918/1304517386.py in <module>
1 f_jit = jit(f)
----> 2 f_jit(x,xi)
/tmp/ipykernel_34918/2530243949.py in f(***failed resolving arguments***)
65 def f(x,xi):
---> 66 return x + np.dot(dH(x),xi)
67
~/.local/lib/python3.10/site-packages/tfc/utils/TFCUtils.py in wrapped(***failed resolving arguments***)
168 tans = tuple(
--> 169 [onesRobust(args[i]) if i == j else zerosRobust(args[i]) for i in range(len(args))]
170 )
/tmp/ipykernel_34918/2530243949.py in H(***failed resolving arguments***)
9 def H(x, d=0):
---> 10 return H_p.bind(x, d=d)
11
/tmp/ipykernel_34918/2530243949.py in H_jvp(***failed resolving arguments***)
43 if type(dx) is batching.BatchTracer:
---> 44 flag = onp.any(dx.val != 0)
45 else:
/tmp/ipykernel_34918/2530243949.py in H(***failed resolving arguments***)
9 def H(x, d=0):
---> 10 return H_p.bind(x, d=d)
11
JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'H' not found for platform cpu
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
NotImplementedError Traceback (most recent call last)
/tmp/ipykernel_34918/1304517386.py in <module>
1 f_jit = jit(f)
----> 2 f_jit(x,xi)
[... skipping hidden 12 frame]
~/.local/lib/python3.10/site-packages/jax/interpreters/mlir.py in jaxpr_subcomp(ctx, jaxpr, tokens, consts, *args)
988 rule = xla_fallback_lowering(eqn.primitive)
989 else:
--> 990 raise NotImplementedError(
991 f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
992 f"found for platform {ctx.platform}")
NotImplementedError: MLIR translation rule for primitive 'H' not found for platform cpu
However, if we use pejit
and set x
as a compile time constant, the value of H(x)
will be pre-computed, and its result will be cached and used by jit
. Therefore, the JIT transform for H(x)
is neither used nor needed.
[6]:
f_pejit = pejit(x,xi,constant_arg_nums=[0])(f)
f_pejit(xi)
[6]:
DeviceArray([ 31., 175., 529.], dtype=float64)
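The same principle can be seen without defining a custom primitive. In the sketch below, which uses plain JAX and NumPy rather than pejit, the hypothetical H_numpy is evaluated with NumPy and therefore cannot be traced by jit. The failure mechanism differs from the missing translation rule above (NumPy cannot consume a JAX tracer), but the remedy is the same: evaluate the non-traceable part once and hand the JIT only the resulting constant.

```python
import numpy as onp
import jax
import jax.numpy as jnp


def H_numpy(x):
    # Hypothetical stand-in for H(x), evaluated with plain NumPy.
    return onp.vstack((x, x**2, x**3, x**4)).T


def f(x, xi):
    return jnp.dot(H_numpy(x), xi)


x = onp.array([1.0, 2.0, 3.0])
xi = jnp.array([1.0, 2.0, 3.0, 4.0])

# Tracing f fails, because NumPy cannot operate on the traced x.
try:
    jax.jit(f)(x, xi)
    jit_failed = False
except Exception:
    jit_failed = True

# Evaluate H once, then JIT a function that only sees the cached result.
y = H_numpy(x)
f_cached = jax.jit(lambda xi: jnp.dot(y, xi))
result = f_cached(xi)
```

Here y plays the role of the value that pejit caches, and f_cached is a function of xi only.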
Ways to use pejit#
pejit
can be used in the ways defined above or as a decorator.
[7]:
f_pejit_1 = pejit(x,xi,constant_arg_nums=[0])(f)
@pejit(x,xi,constant_arg_nums=[0])
def f_pejit_2(x,xi):
    return x + np.dot(dH(x),xi)
print("Result of f_pejit_1")
print(f_pejit_1(xi))
print("\nResult of f_pejit_2")
print(f_pejit_2(xi))
Result of f_pejit_1
[ 31. 175. 529.]
Result of f_pejit_2
[ 31. 175. 529.]
Reminders#
* pejit removes the arguments listed in constant_arg_nums from the compiled function. Therefore, if x of f(x,xi) is included in constant_arg_nums, then f_pejit will only be a function of xi, i.e., f_pejit(xi).
* pejit calculates the values associated with constant_arg_nums when it is called. Therefore, if x is changed in the Python code after pejit is called, the values of H(x) cached by pejit will not be recomputed.
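The second reminder can be illustrated with plain JAX closures. This is a sketch of the general behavior of values captured at build time, not of pejit itself; build is a hypothetical helper that evaluates the x-dependent part once and returns a JITed function of xi.

```python
import jax
import jax.numpy as jnp


def build(x):
    # Hypothetical helper: evaluate the x-dependent part once and close over it.
    y = jnp.sin(x)
    return jax.jit(lambda xi: y * xi)


x = jnp.array([0.0, 1.0])
xi = jnp.ones(2)

g = build(x)
before = g(xi)

x = x + 1.0   # x now refers to new values
after = g(xi)  # still uses the y computed from the old x

g_rebuilt = build(x)  # rebuilding picks up the new x
rebuilt = g_rebuilt(xi)
```

Changing x after the function is built has no effect on g. To use new values of x, the function has to be built again, just as pejit would need to be called again.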