JAX Inference Profiling
See the Quick Start Guide on how to install and configure the profiler.
To profile JAX, wrap each inference, e.g. a prediction, with the following code. All inferences are measured, but only a subset is profiled to keep overhead low. See the profiling API reference for full documentation.
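To illustrate the measure-everything, profile-a-few approach, here is a minimal sketch of a sampling context manager. It is purely illustrative and assumes nothing about Graphsignal's internals; the names `sampling_profiler` and `PROFILE_FIRST_N` are made up for this example:

```python
import time
from contextlib import contextmanager

# Illustrative sketch only: time every inference cheaply, but mark only
# the first few for detailed profiling to keep overhead low.
PROFILE_FIRST_N = 3
_inference_count = 0
measurements = []

@contextmanager
def sampling_profiler():
    global _inference_count
    _inference_count += 1
    detailed = _inference_count <= PROFILE_FIRST_N  # profile only a few
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        measurements.append({"time": elapsed, "profiled": detailed})

# Every inference is measured, but only the first few are profiled.
for _ in range(10):
    with sampling_profiler():
        sum(i * i for i in range(1000))  # stand-in for a prediction

print(len(measurements))                                  # 10 measured
print(sum(1 for m in measurements if m["profiled"]))      # 3 profiled
```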
Profile JAX using the `with` context manager:

```python
from graphsignal.profilers.jax import profile_inference

with profile_inference():
    # single or batch prediction
    ...
```
Profile using the `stop` method:

```python
from graphsignal.profilers.jax import profile_inference

span = profile_inference()
# single or batch prediction
span.stop()
```
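With the manual form, wrapping the prediction in `try`/`finally` ensures the span is stopped even if inference raises. A sketch of the pattern using a stand-in span object (`FakeSpan` is hypothetical, not the Graphsignal class):

```python
# FakeSpan is a stand-in to show the try/finally pattern; in real code the
# span comes from graphsignal.profilers.jax.profile_inference().
class FakeSpan:
    def __init__(self):
        self.stopped = False

    def stop(self):
        self.stopped = True

def run_inference(span, fail=False):
    try:
        if fail:
            raise RuntimeError("prediction failed")  # simulated inference error
        return "ok"
    finally:
        span.stop()  # runs whether inference succeeds or fails

span = FakeSpan()
try:
    run_inference(span, fail=True)
except RuntimeError:
    pass
print(span.stopped)  # span was still stopped after the failure
```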
Examples
The JAX MNIST example illustrates where and how to add the `profile_inference` call.
Distributed workloads
Graphsignal provides built-in support for distributed inference. Depending on the platform, it may be necessary to provide a run ID to all workers. Refer to the Distributed Workloads section for more information.
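One common way to share a run ID across workers is to generate it once in the launcher and pass it down via an environment variable. A sketch under that assumption; the variable name `INFERENCE_RUN_ID` is made up for this example and is not a Graphsignal convention:

```python
import os
import uuid

# Sketch: the launcher generates one run ID and exports it so that all
# workers report under the same run. INFERENCE_RUN_ID is a made-up name.
def get_run_id():
    run_id = os.environ.get("INFERENCE_RUN_ID")
    if run_id is None:
        run_id = uuid.uuid4().hex
        os.environ["INFERENCE_RUN_ID"] = run_id  # visible to child processes
    return run_id

# Both "workers" in this process observe the same ID.
first = get_run_id()
second = get_run_id()
print(first == second)  # both calls return the shared ID
```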