JAX Inference Profiling

See the Quick Start Guide for instructions on installing and configuring the profiler.

To profile JAX, add the following code around each inference, such as a single or batch prediction. Every inference is measured, but only a small sample is profiled in detail, which keeps overhead low. See the profiling API reference for full documentation.
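The "measure every inference, profile only a few" idea can be illustrated with a small sketch. It is a hypothetical illustration, not Graphsignal's implementation, and the "every Nth call" rule (profile_every) is an assumption chosen for simplicity. The actual sampling policy is not described here. The sketch times every call cheaply and flags only a subset for expensive detailed profiling.

```python
import time
from contextlib import contextmanager


class SampledProfiler:
    """Hypothetical illustration: time every call, profile only every Nth one.

    This is NOT Graphsignal's implementation or sampling policy; it only
    illustrates the "measure all, profile a few" idea.
    """

    def __init__(self, profile_every: int = 100):
        self.profile_every = profile_every
        self.measured = 0    # number of inferences timed
        self.profiled = 0    # number of inferences selected for detailed profiling
        self.latencies = []  # wall-clock duration of every inference, in seconds

    @contextmanager
    def inference(self):
        # Decide up front whether this call gets the expensive detailed profile.
        should_profile = self.measured % self.profile_every == 0
        start = time.perf_counter()
        try:
            yield should_profile
        finally:
            # Cheap measurement happens for every call, even if the body raised.
            self.latencies.append(time.perf_counter() - start)
            self.measured += 1
            if should_profile:
                self.profiled += 1


profiler = SampledProfiler(profile_every=5)
for _ in range(12):
    # `detailed` tells the caller whether this call was sampled for profiling.
    with profiler.inference() as detailed:
        sum(range(1000))  # stand-in for a prediction
print(profiler.measured, profiler.profiled)  # 12 3
```

Calls 0, 5 and 10 are sampled, so all 12 calls are timed but only 3 receive the detailed treatment.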

Profile JAX using the with statement (context manager):

from graphsignal.profilers.jax import profile_inference

with profile_inference():
    # single or batch prediction goes here
    ...

Profile using an explicit stop() call:

from graphsignal.profilers.jax import profile_inference

span = profile_inference()
# single or batch prediction goes here
span.stop()
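If the prediction can raise an exception, calling stop() in a finally block guarantees that the span is closed either way, much as the with form does on exit. The sketch below replaces the real profile_inference with a hypothetical stand-in (_StandInSpan, which only records whether stop() was called) so that it runs without Graphsignal installed. The predict function is also hypothetical. In real code, import profile_inference from graphsignal.profilers.jax as shown above.

```python
class _StandInSpan:
    """Hypothetical stand-in for the span returned by profile_inference().

    It only records whether stop() was called; it does no profiling.
    """

    def __init__(self):
        self.stopped = False

    def stop(self):
        self.stopped = True


spans = []  # keeps every stand-in span so it can be inspected afterwards


def profile_inference():
    # Stand-in only: real code imports profile_inference from
    # graphsignal.profilers.jax instead of defining it here.
    span = _StandInSpan()
    spans.append(span)
    return span


def predict(batch):
    # Hypothetical prediction function that fails on an empty batch.
    if not batch:
        raise ValueError("empty batch")
    return [x * 2 for x in batch]


def profiled_predict(batch):
    span = profile_inference()
    try:
        return predict(batch)
    finally:
        # Runs on both success and error, so the span is always stopped.
        span.stop()


print(profiled_predict([1, 2, 3]))  # [2, 4, 6]
try:
    profiled_predict([])
except ValueError:
    pass
print([s.stopped for s in spans])  # [True, True]
```

The second call raises, but its span is still stopped. When the profiled region fits in one block, the with form gives the same guarantee with less code.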

Examples

The JAX MNIST example illustrates where and how to add the profile_inference function.

Distributed workloads

Graphsignal provides built-in support for distributed inference. Depending on the platform, you may need to pass a run ID to every worker. Refer to the Distributed Workloads section for more information.