PEJIT (Partial Evaluation Just In Time)#

This notebook explains how to use the pejit function. It works very similarly to JAX’s jit function, but pejit can cache variables, and inner-functions that depend only on those variables, as compile-time constants. This can improve performance if the inner-functions are computationally intensive. It also means that inner-functions that are JAX primitives do not need JIT transforms in order to be JITed: their return values are cached, so the inner-functions themselves are never exposed to the JIT.

pejit function#

Positional input arguments#

The positional arguments to pejit should be the same type/size/shape as the positional arguments to the function. pejit uses these internally when tracing.

Note that pejit returns a transformation that is then applied to a function: pass the function to be JITed as the argument to the result of pejit, i.e., pejit(...)(f). Alternatively, pejit can be used as a decorator. See the examples below for more details.

Optional input keyword arguments#

pejit takes the same optional keyword arguments as JAX’s jit. See their page for more details. In addition, pejit takes the constant_arg_nums keyword, a list of integers that defines which positional arguments of the function being JITed will be treated as compile-time constants. Inner-functions that depend only on these constants will also be treated as compile-time constants.

Outputs#

The output of this function is a JITed function from which the positional arguments specified by constant_arg_nums have been removed.
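The argument removal described above can be pictured with a small pure-Python sketch. The helper `bind_constant_args` is hypothetical and is not part of TFC or JAX; it only mimics how the positions listed in constant_arg_nums disappear from the returned function's signature, and it does no tracing, caching, or compilation.

```python
def bind_constant_args(fun, example_args, constant_arg_nums):
    """Hypothetical helper: fix the arguments at constant_arg_nums.

    Returns a function that takes only the remaining positional
    arguments, in their original order.
    """
    constants = {i: example_args[i] for i in constant_arg_nums}
    n_total = len(example_args)

    def wrapped(*dynamic_args):
        if len(dynamic_args) != n_total - len(constants):
            raise TypeError("wrong number of non-constant arguments")
        dynamic_iter = iter(dynamic_args)
        # Rebuild the full argument list: constants stay in place,
        # the remaining slots are filled from the call arguments.
        full_args = [
            constants[i] if i in constants else next(dynamic_iter)
            for i in range(n_total)
        ]
        return fun(*full_args)

    return wrapped


def f(x, xi):
    return 10 * x + xi


# x (position 0) is fixed at 2, so f_bound only takes xi
f_bound = bind_constant_args(f, (2, 0), constant_arg_nums=[0])
result = f_bound(5)  # same as f(2, 5)
```

Here the second entry of the example arguments is only a placeholder, much like the example xi passed to pejit is only used for its type and shape.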

Traces of jit vs. pejit#

As an example, suppose we have a function like

\[f(x,\xi) = H(x) \cdot \xi\]

and we want to treat \(x\) as a compile time constant, i.e., if we let \(y = H(x)\) then

\[f_{jit}(\xi) = y \cdot \xi\]

To begin, let’s utilize some of the TFC basis functions for \(H(x)\).

[1]:
import jax.numpy as np
from tfc import utfc

# Define H(x) using Chebyshev orthogonal polynomials
tfc = utfc(3, 0, 2, x0=0.0, xf=1.0)
H = tfc.H
x = tfc.x

# Define an example xi
xi = np.ones(H(x).shape[1])

# Define f(x,xi)
f = lambda x,xi: np.dot(H(x),xi)

Under the hood, pejit uses pe to partially evaluate the function while treating certain arguments as constants. Below, the JAX traces of the original function and of the function after pe is applied are printed out so you can see the difference.

[2]:
from jax import make_jaxpr
from tfc.utils.TFCUtils import pe

jaxpr_f = make_jaxpr(f)
f_pe = pe(x,xi,constant_arg_nums=[0])(f)
jaxpr_f_pe = make_jaxpr(f_pe)

print("Original function")
print(jaxpr_f(x,xi))
print("\nFunction after pe")
print(jaxpr_f_pe(xi))
Original function
{ lambda ; a:f64[3] b:f64[3]. let
    c:f64[3,3] = H[d=0 full=False] a
    d:f64[3] = dot_general[
      dimension_numbers=(((1,), (0,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] c b
  in (d,) }

Function after pe
{ lambda a:f64[3,3]; b:f64[3]. let
    c:f64[3] = dot_general[
      dimension_numbers=(((1,), (0,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] a b
  in (c,) }

The original function takes in two arguments, x and xi (called a and b in the trace). It has no known constants. It uses x to compute H(x) and stores the result in c. Finally, it takes the dot product between H(x) and xi and stores the result in d, which it returns.

The pe function takes in one argument, xi (b in the trace). It has one stored constant, H(x) (called a in the trace). It takes the dot product between H(x) and xi and stores the result in c, which it returns. Notice that this trace has pre-computed H(x) and stored the result in a, so x is not needed and H(x) is not evaluated at run-time.
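The same constant-versus-argument distinction can be reproduced with plain JAX, without TFC. In the sketch below, `basis` is a hypothetical stand-in for H(x). Its value is computed eagerly and captured by a closure, which only approximates the effect of pe and should not be read as its actual implementation.

```python
import jax
import jax.numpy as jnp

def basis(x):
    # Hypothetical stand-in for H(x): columns 1, x, x**2
    return jnp.stack((jnp.ones_like(x), x, x**2), axis=1)

def f(x, xi):
    return jnp.dot(basis(x), xi)

x = jnp.array([0.0, 0.5, 1.0])
xi = jnp.ones(3)

# Evaluate the x-only part once, outside of any trace
y = basis(x)

def f_closed(xi):
    # y is captured by the closure instead of being passed in
    return jnp.dot(y, xi)

jaxpr_full = jax.make_jaxpr(f)(x, xi)
jaxpr_closed = jax.make_jaxpr(f_closed)(xi)

print(jaxpr_full)    # two inputs; the basis is rebuilt inside the trace
print(jaxpr_closed)  # one input; the basis is no longer an operation
```

The exact text of the printed traces depends on the installed JAX version, so it is not reproduced here.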

Let’s take a look at the compiled results of these functions. For the pe function, we will simply use pejit on f, which runs pe before jit. Alternatively, we could have taken the pe function above and passed it through the regular jit and gotten the same result.

[3]:
from jax import jit
from tfc.utils import pejit

# JIT f(x,xi) using the regular jax.jit and print out the compiled code
print("JAX jit")
f_jit = jit(f)
f_jit_lowered = f_jit.lower(x,xi)
print(f_jit_lowered.compile().compiler_ir()[0].to_string())

# PEJIT f(x,xi) using tfc.utils.pejit and print out the compiled code
print("TFC pejit")
f_pejit = pejit(x, xi, constant_arg_nums=[0])(f)
f_pejit_lowered = f_pejit.lower(xi)
print(f_pejit_lowered.compile().compiler_ir()[0].to_string())
JAX jit
HloModule jit__lambda_.12, entry_computation_layout={(f64[3]{0},f64[3]{0})->f64[3]{0}}

ENTRY %main.14 (Arg_0.1: f64[3], Arg_1.2: f64[3]) -> f64[3] {
  %constant.3 = s64[] constant(0)
  %Arg_0.1 = f64[3]{0} parameter(0)
  %constant.4 = pred[] constant(false)
  %constant.8 = s64[] constant(3)
  %custom-call.1 = f64[3,3]{1,0} custom-call(s64[] %constant.3, f64[3]{0} %Arg_0.1, s64[] %constant.3, pred[] %constant.4, s64[] %constant.8, /*index=5*/s64[] %constant.8), custom_call_target="BasisFunc0", metadata={op_name="custom-call.7"}
  %Arg_1.2 = f64[3]{0} parameter(1)
  ROOT %dot.13 = f64[3]{0} dot(f64[3,3]{1,0} %custom-call.1, f64[3]{0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_34918/1347355215.py" source_line=13}
}


TFC pejit
HloModule jit__lambda_.13, entry_computation_layout={(f64[3]{0})->f64[3]{0}}

ENTRY %main.4 (Arg_0.1: f64[3]) -> f64[3] {
  %constant.2 = f64[3,3]{1,0} constant({ { 1, -1, 1 }, { 1, 0, -1 }, { 1, 1, 1 } })
  %Arg_0.1 = f64[3]{0} parameter(0)
  ROOT %dot.3 = f64[3]{0} dot(f64[3,3]{1,0} %constant.2, f64[3]{0} %Arg_0.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_34918/1347355215.py" source_line=13}
}


The compiled outputs are a bit harder to parse than the traces above. However, we can still see the same features:

  • The TFC pejit result is shorter than the JAX jit result because it does not compute H(x).

  • The TFC pejit result has a stored 3x3 constant (called constant.2), which is the result of H(x).

No need for JIT transforms#

As mentioned above, the inner-functions whose values are cached by pejit do not need a jit transform in order to be run through pejit. This means that we can still “compile” results that utilize these inner-functions. Moreover, these inner-functions can still utilize other JAX transforms. To illustrate, below is a simple function H(x) whose gradient transforms have been implemented, but whose JIT transform has not been defined. f(x,xi) uses the derivative of H(x) in its calculation.

[4]:
import numpy as onp
from jax import core
from jax.interpreters import ad, batching
from tfc.utils import egrad

# Define a simple function with gradient transformations.
H_p = core.Primitive("H")

def H(x, d=0):
    return H_p.bind(x, d=d)

# Implementation (concrete evaluation rule)
def H_impl(x, d=0):
    if d == 0:
        return onp.vstack((x, x**2, x**3, x**4)).T
    elif d == 1:
        return onp.vstack((onp.ones_like(x), 2*x, 3*x**2, 4*x**3)).T
    else:
        raise ValueError("Derivatives of order 2 or higher have not been implemented yet.")

H_p.def_impl(H_impl)

# Abstract evaluation
def H_abstract_eval(x, d=0):
    dim1 = 4
    if len(x.shape) == 0:
        dims = (dim1,)
    else:
        dims = (x.shape[0], dim1)
    return core.ShapedArray(dims, x.dtype)

H_p.def_abstract_eval(H_abstract_eval)

# Define batching translation
def H_batch(vec, batch, d=0):
    return H(*vec, d=d), batch[0]

# Define Jacobian-vector product
def H_jvp(arg_vals, arg_tans, d=0, full=False):
    x = arg_vals[0]
    dx = arg_tans[0]
    if not isinstance(dx, ad.Zero):
        if type(dx) is batching.BatchTracer:
            flag = onp.any(dx.val != 0)
        else:
            flag = onp.any(dx != 0)
    else:
        flag = False
    if flag:
        # Nonzero tangent: scale the next derivative of H by dx
        scale = onp.expand_dims(dx, 1) if len(dx.shape) == 1 else dx
        out_tans = H(x, d=d + 1) * scale
    else:
        # Zero tangent; H(x) has 4 columns (see H_abstract_eval)
        out_tans = np.zeros((x.shape[0], 4), dtype=x.dtype)
    return (H(x, d=d), out_tans)

ad.primitive_jvps[H_p] = H_jvp

# Define f(x,xi) that uses H(x)
x = np.array([1.,2.,3.])
xi = np.array([1.,2.,3.,4.])

dH = egrad(H)
def f(x,xi):
    return x + np.dot(dH(x),xi)

f(x,xi)
[4]:
DeviceArray([ 31., 175., 529.], dtype=float64)

If we try to call a regular jit on f(x,xi), we get an error because the JIT transform for H(x) (its MLIR translation rule) is undefined.

[5]:
f_jit = jit(f)
f_jit(x,xi)
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
/usr/lib/python3.10/runpy.py in _run_module_as_main(***failed resolving arguments***)
    190     except _Error as exc:
--> 191         msg = "%s: %s" % (sys.executable, exc)
    192         sys.exit(msg)

/usr/lib/python3.10/runpy.py in _run_code(***failed resolving arguments***)
     74         loader = mod_spec.loader
---> 75         fname = mod_spec.origin
     76         cached = mod_spec.cached

/usr/lib/python3/dist-packages/ipykernel_launcher.py in <module>
     11     # This is added back by InteractiveShellApp.init_path()
---> 12     if sys.path[0] == '':
     13         del sys.path[0]

/usr/lib/python3/dist-packages/traitlets/config/application.py in launch_instance(***failed resolving arguments***)
    843         """
--> 844         app = cls.instance(**kwargs)
    845         app.initialize(argv)

/usr/lib/python3/dist-packages/ipykernel/kernelapp.py in start(***failed resolving arguments***)
    667         if self.trio_loop:
--> 668             from ipykernel.trio_runner import TrioRunner
    669             tr = TrioRunner()

/usr/lib/python3/dist-packages/tornado/platform/asyncio.py in start(***failed resolving arguments***)
    194         except (RuntimeError, AssertionError):
--> 195             old_loop = None  # type: ignore
    196         try:

/usr/lib/python3.10/asyncio/base_events.py in run_forever(***failed resolving arguments***)
    593
--> 594         old_agen_hooks = sys.get_asyncgen_hooks()
    595         sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,

/usr/lib/python3.10/asyncio/base_events.py in _run_once(***failed resolving arguments***)
   1859
-> 1860         event_list = self._selector.select(timeout)
   1861         self._process_events(event_list)

/usr/lib/python3.10/asyncio/events.py in _run(***failed resolving arguments***)
     79         try:
---> 80             self._context.run(self._callback, *self._args)
     81         except (SystemExit, KeyboardInterrupt):

/usr/lib/python3/dist-packages/ipykernel/kernelbase.py in dispatch_queue(***failed resolving arguments***)
    460             try:
--> 461                 await self.process_one()
    462             except Exception:

/usr/lib/python3/dist-packages/ipykernel/kernelbase.py in process_one(***failed resolving arguments***)
    446             try:
--> 447                 t, dispatch, args = self.msg_queue.get_nowait()
    448             except (asyncio.QueueEmpty, QueueEmpty):

/usr/lib/python3/dist-packages/ipykernel/kernelbase.py in dispatch_shell(***failed resolving arguments***)
    333             self.shell_stream.flush(zmq.POLLOUT)
--> 334             return
    335

/usr/lib/python3/dist-packages/ipykernel/kernelbase.py in execute_request(***failed resolving arguments***)
    633             self.log.error("Got bad msg: ")
--> 634             self.log.error("%s", parent)
    635             return

/usr/lib/python3/dist-packages/ipykernel/ipkernel.py in do_execute(***failed resolving arguments***)
    321
--> 322             if (
    323                 _asyncio_runner

/usr/lib/python3/dist-packages/ipykernel/zmqshell.py in run_cell(***failed resolving arguments***)
    531         self._last_traceback = None
--> 532         return super().run_cell(*args, **kwargs)
    533

/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py in run_cell(***failed resolving arguments***)
   2913         try:
-> 2914             result = self._run_cell(
   2915                 raw_cell, store_history, silent, shell_futures)

/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py in _run_cell(***failed resolving arguments***)
   2935         coro = self.run_cell_async(
-> 2936             raw_cell,
   2937             store_history=store_history,

/usr/lib/python3/dist-packages/IPython/core/async_helpers.py in _pseudo_sync_runner(***failed resolving arguments***)
     77     try:
---> 78         coro.send(None)
     79     except StopIteration as exc:

/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py in run_cell_async(***failed resolving arguments***)
      1 # -*- coding: utf-8 -*-
      2 """Main IPython class."""
      3

/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py in run_ast_nodes(***failed resolving arguments***)
   3342                                  "please try to upgrade IPython and open a bug report with your case.")
-> 3343             if _async:
   3344                 # If interactivity is async the semantics of run_code are

/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py in run_code(***failed resolving arguments***)
   3450                 if async_ and sys.version_info < (3,8):
-> 3451                     last_expr = (await self._async_exec(code_obj, self.user_ns))
   3452                     code = compile('last_expr', 'fake', "single")

/tmp/ipykernel_34918/1304517386.py in <module>
      1 f_jit = jit(f)
----> 2 f_jit(x,xi)

/tmp/ipykernel_34918/2530243949.py in f(***failed resolving arguments***)
     65 def f(x,xi):
---> 66     return x + np.dot(dH(x),xi)
     67

~/.local/lib/python3.10/site-packages/tfc/utils/TFCUtils.py in wrapped(***failed resolving arguments***)
    168         tans = tuple(
--> 169             [onesRobust(args[i]) if i == j else zerosRobust(args[i]) for i in range(len(args))]
    170         )

/tmp/ipykernel_34918/2530243949.py in H(***failed resolving arguments***)
      9 def H(x, d=0):
---> 10     return H_p.bind(x, d=d)
     11

/tmp/ipykernel_34918/2530243949.py in H_jvp(***failed resolving arguments***)
     43         if type(dx) is batching.BatchTracer:
---> 44             flag = onp.any(dx.val != 0)
     45         else:

/tmp/ipykernel_34918/2530243949.py in H(***failed resolving arguments***)
      9 def H(x, d=0):
---> 10     return H_p.bind(x, d=d)
     11

JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'H' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

NotImplementedError                       Traceback (most recent call last)
/tmp/ipykernel_34918/1304517386.py in <module>
      1 f_jit = jit(f)
----> 2 f_jit(x,xi)

    [... skipping hidden 12 frame]

~/.local/lib/python3.10/site-packages/jax/interpreters/mlir.py in jaxpr_subcomp(ctx, jaxpr, tokens, consts, *args)
    988         rule = xla_fallback_lowering(eqn.primitive)
    989       else:
--> 990         raise NotImplementedError(
    991             f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
    992             f"found for platform {ctx.platform}")

NotImplementedError: MLIR translation rule for primitive 'H' not found for platform cpu

However, if we use pejit and set x as a compile time constant, the value of H(x) will be pre-computed, and its result will be cached and used by jit. Therefore, the transform for H(x) is never used nor needed.
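A loosely analogous situation can be set up with plain JAX and NumPy. In the sketch below, `numpy_only_basis` is a hypothetical function written in ordinary NumPy, so it cannot be traced by jit. Evaluating it eagerly and capturing the result in a closure keeps it out of the trace entirely. This is an illustration of the idea under those assumptions, not a description of how pejit is implemented.

```python
import numpy as onp
import jax
import jax.numpy as jnp

def numpy_only_basis(x):
    # Plain NumPy code: it cannot operate on JAX tracers
    x = onp.asarray(x)
    return onp.vstack((x, x**2)).T

def f(x, xi):
    return jnp.dot(numpy_only_basis(x), xi)

x = jnp.array([1.0, 2.0, 3.0])
xi = jnp.array([1.0, 2.0])

try:
    jax.jit(f)(x, xi)
    jit_failed = False
except jax.errors.TracerArrayConversionError:
    # Under jit, x is a tracer and cannot be converted to a NumPy array
    jit_failed = True

# Evaluating the x-only part eagerly keeps NumPy out of the trace
y = jnp.asarray(numpy_only_basis(x))
f_cached = jax.jit(lambda xi: jnp.dot(y, xi))
result = f_cached(xi)
```

The jitted closure only ever sees the array y, so it never needs to know how y was produced.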

[6]:
f_pejit = pejit(x,xi,constant_arg_nums=[0])(f)
f_pejit(xi)
[6]:
DeviceArray([ 31., 175., 529.], dtype=float64)

Ways to use pejit#

pejit can be used in the ways defined above or as a decorator.

[7]:
f_pejit_1 = pejit(x,xi,constant_arg_nums=[0])(f)
@pejit(x,xi,constant_arg_nums=[0])
def f_pejit_2(x,xi):
    return x + np.dot(dH(x),xi)

print("Result of f_pejit_1")
print(f_pejit_1(xi))
print("\nResult of f_pejit_2")
print(f_pejit_2(xi))
Result of f_pejit_1
[ 31. 175. 529.]

Result of f_pejit_2
[ 31. 175. 529.]

Reminders#

  • pejit removes the arguments specified in constant_arg_nums from the compiled function. Therefore, if x of f(x,xi) is included in constant_arg_nums, then f_pejit will only be a function of xi, i.e., f_pejit(xi).

  • pejit calculates the values associated with the constant_arg_nums when it is called. Therefore, if x is changed in the Python code after pejit is called, the values of H(x) cached by pejit will not be recomputed.
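The second reminder can be illustrated with a plain JAX closure, which behaves similarly with respect to stale values. The factory `make_scaled` below is hypothetical; it stands in for building a function whose x-dependent part is evaluated once and then captured as a constant.

```python
import jax
import jax.numpy as jnp

def make_scaled(x):
    # Hypothetical factory: x**2 is evaluated now and captured as a constant
    y = x**2
    return jax.jit(lambda xi: y * xi)

x = jnp.array([1.0, 2.0, 3.0])
xi = jnp.ones(3)

g = make_scaled(x)
before = g(xi)

# Rebinding x afterwards does not touch the value captured by g
x = jnp.array([10.0, 20.0, 30.0])
after = g(xi)

# To use the new x, build the function again
g_new = make_scaled(x)
refreshed = g_new(xi)
```

By the same reasoning, a function produced by pejit would presumably need to be rebuilt with the new x if updated cached values are required.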

Further resources#

  • If you’re interested in learning more, the issue where pejit was designed is located here.

  • The issue noted above led to a JAX discussion here, which includes more information on how pejit was designed. A shout out to YouJiacheng for working with me to come up with this solution!