Finding Optimal Batch Size for ONNX Model
By Dmitri Melikyan | 1 min read

An example of selecting the most efficient inference parameters with the help of Graphsignal.

This article uses one simple example, choosing the most efficient inference batch size, to show how Graphsignal can be used. The same process can be applied to other types of inference optimization, such as model selection and weight pruning.

The following code is executed in several trial runs, each with a different batch size.

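import graphsignal
import onnxruntime
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST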

# expects `api_key` argument or `GRAPHSIGNAL_API_KEY` environment variable
graphsignal.configure(deployment='onnx-mnist-local')

...

test_ds = MNIST(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())
test_loader = DataLoader(test_ds, batch_size=args.batch_size)

sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

session = onnxruntime.InferenceSession(TEST_MODEL_PATH, sess_options)

for x, y in test_loader:
    x = x.detach().cpu().numpy().reshape((x.shape[0], 28 * 28))

    # Measure inference batch.
    with graphsignal.start_trace(endpoint='predict', tags=dict(batch_size=args.batch_size)) as trace:
        trace.set_data('input', x)
        preds = session.run(None, { 'input': x })
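The script above reads the batch size from `args.batch_size`, so each trial is a separate run. As a minimal sketch of an alternative, the loop below sweeps several batch sizes in one process and records per-batch latency with `time.perf_counter`. It is not part of the original example: `run_trial`, the stand-in `predict` function, and the synthetic `inputs` are hypothetical placeholders. In the real script, `predict` would wrap `session.run(None, {'input': x})`.

```python
import time


def run_trial(predict, inputs, batch_size):
    """Call `predict` on consecutive batches; return per-batch latencies in seconds."""
    latencies = []
    for start in range(0, len(inputs), batch_size):
        batch = inputs[start:start + batch_size]
        t0 = time.perf_counter()
        predict(batch)
        latencies.append(time.perf_counter() - t0)
    return latencies


def predict(batch):
    # Hypothetical stand-in for the ONNX session call; it only sums each row.
    return [sum(row) for row in batch]


# Synthetic input: 1024 flattened 28x28 "images" (assumption, not MNIST data).
inputs = [[0.0] * (28 * 28) for _ in range(1024)]

# One trial per batch size, all in the same process.
results = {}
for batch_size in (1, 16, 64, 256):
    results[batch_size] = run_trial(predict, inputs, batch_size)
```

Keeping the timing inside `run_trial` means the same helper can be reused with any callable, which makes it easy to compare a local measurement against what the trace reports.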

After running this code with different batch sizes, we can identify the run whose batch size gives the best latency and throughput. Larger batch sizes are generally better, but, as expected, the improvement is no longer linear beyond a batch size of 256.

To continue the optimization process, we can inspect the inference trace and look for bottlenecks that can be improved.

To try it out, see the Quick Start Guide for instructions.
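To make the comparison between runs concrete, here is a rough sketch of how per-batch latencies could be turned into throughput figures and compared across batch sizes. The helper names `summarize` and `throughput_gains` are hypothetical, and the input is assumed to be a dictionary mapping each batch size to a list of measured per-batch latencies in seconds, however those were collected.

```python
from statistics import mean


def summarize(latencies_by_batch_size):
    """Map batch size -> (mean batch latency in seconds, throughput in items/second).

    Assumes every timed batch held exactly `batch_size` items, so a final
    partial batch should be dropped before calling this function.
    """
    summary = {}
    for batch_size, latencies in sorted(latencies_by_batch_size.items()):
        avg_latency = mean(latencies)
        summary[batch_size] = (avg_latency, batch_size / avg_latency)
    return summary


def throughput_gains(summary):
    """Return the throughput ratio between each pair of consecutive batch sizes."""
    sizes = sorted(summary)
    return {
        (small, large): summary[large][1] / summary[small][1]
        for small, large in zip(sizes, sizes[1:])
    }
```

If throughput scaled perfectly with batch size, each ratio would equal `large / small`. Ratios that fall well below that value mark the point where a larger batch stops paying off.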