Finding Optimal Batch Size for ONNX Model
By Dmitri Melikyan | 1 min read

An example of selecting the most efficient inference parameters using a profiler.

This article walks through a simple example of using an inference profiler to choose the most efficient inference batch size. The same process can be applied to other types of inference optimization, such as model selection, weight pruning, and many others.

The following code performs several trial runs with different batch sizes.

# Assumed imports for the PyTorch / ONNX Runtime setup used below.
import torch
import onnxruntime
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchmetrics import Accuracy  # assumed source of the Accuracy metric

import graphsignal
from graphsignal.profilers.onnxruntime import initialize_profiler, profile_inference

# expects `api_key` argument or `GRAPHSIGNAL_API_KEY` environment variable
graphsignal.configure(workload_name='ONNX MNIST inference')

...

for batch_size in (1, 8, 16, 32, 64, 128, 256, 512, 1024):
    test_ds = MNIST(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

    # Initialize profiler for ONNX inference session.
    initialize_profiler(sess_options)

    session = onnxruntime.InferenceSession(TEST_MODEL_PATH, sess_options)

    test_acc = Accuracy()
    for x, y in test_loader:
        x = x.detach().cpu().numpy().reshape((x.shape[0], 28 * 28))

        # Measure and profile inference batch.
        with profile_inference(session, batch_size=batch_size):
            preds = session.run(None, { 'input': x })

        test_acc.update(torch.tensor(preds[0]), y)

    # Log test accuracy if necessary.
    graphsignal.log_metric('test_acc', test_acc.compute().item())

    # End current trial run and start a new one.
    graphsignal.end_run()

After running this code, we can identify the run whose batch size yields the best latency and throughput. Bigger batch sizes are generally better but, as expected, the improvement is no longer linear after batch size 256.

ONNX batch size benchmark
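
For a quick local sanity check of the latency and throughput numbers reported for each run, a given batch size can also be timed directly with time.perf_counter. This is a minimal sketch, not part of the profiler: the measure_throughput helper is hypothetical, and it assumes the same TEST_MODEL_PATH and a float32 input named 'input' as in the code above.

import time
import numpy as np
import onnxruntime

def measure_throughput(session, batch, n_iters=50):
    # Warm-up run so one-time initialization does not skew the timing.
    session.run(None, {'input': batch})
    start = time.perf_counter()
    for _ in range(n_iters):
        session.run(None, {'input': batch})
    elapsed = time.perf_counter() - start
    latency = elapsed / n_iters
    throughput = batch.shape[0] / latency
    return latency, throughput

# Example: a random MNIST-shaped batch of 256 flattened 28x28 images.
session = onnxruntime.InferenceSession(TEST_MODEL_PATH)
batch = np.random.rand(256, 28 * 28).astype(np.float32)
latency, throughput = measure_throughput(session, batch)
print(f'latency: {latency * 1000:.2f} ms, throughput: {throughput:.0f} samples/sec')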

To continue the optimization process, we can examine the inference trace and look for bottlenecks that can be improved.

ONNX batch size trace
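
Besides the trace shown in the dashboard, ONNX Runtime's built-in profiler can be enabled as a complementary, local way to look at per-operator timings. A minimal sketch, independent of the Graphsignal profiler; it reuses the TEST_MODEL_PATH from the code above.

import onnxruntime

# Enable ONNX Runtime's own profiler for this session.
sess_options = onnxruntime.SessionOptions()
sess_options.enable_profiling = True
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

session = onnxruntime.InferenceSession(TEST_MODEL_PATH, sess_options)
# ... run inference batches as above ...

# end_profiling() writes a Chrome-trace-compatible JSON file and returns its path;
# it can be opened in chrome://tracing or Perfetto to inspect operator-level timings.
trace_file = session.end_profiling()
print('ONNX Runtime trace written to:', trace_file)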

To try it out, see the Quick Start Guide for instructions.