Hyperparameter scaling with batch size

20 September 2020
Recently I had an idea: if you scale the hyperparameters correctly, there shouldn't be any real difference between training with different batch sizes - if anything, the results should even be better with a smaller batch size. An analogy would be smoothing and numerically integrating a function, where we do the smoothing because it is cheap but sampling the function is expensive. When we increase the sampling frequency from once every 32 steps to once every 4 steps, we can't keep multiplying each sample value by 32 - similarly, with a smaller batch size we should also scale down the learning rate.
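To make the analogy concrete, here is a minimal sketch in Python (the integrand and the interval are made up for this illustration, they are not taken from the experiments). When the sampling interval drops from 32 to 4, the weight given to each sample has to drop by the same factor, otherwise the estimate of the integral blows up by 8x:

    import math

    def f(x):
        # Some function that is expensive to sample (made up for this illustration).
        return math.sin(x) + 1.0

    def riemann_sum(step, end=1000.0):
        # Approximate the integral of f over [0, end] by sampling every `step` units.
        # Each sample is weighted by the sampling interval - the analogue of
        # scaling the learning rate down when the batch size shrinks.
        total, x = 0.0, 0.0
        while x < end:
            total += f(x) * step
            x += step
        return total

    coarse = riemann_sum(step=32)  # few expensive samples, large per-sample weight
    fine = riemann_sum(step=4)     # 8x more samples, 8x smaller per-sample weight
    print(coarse, fine)            # both approximate the same integral (roughly 1000)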

Experimental setup

For all the experiments I used imagenette with the 20-epoch top result as it stood when I started working on the project (there is a different top result now).

At the beginning I experimented with scaling all the optimizer hyperparameters - learning rate, momentum, squared momentum and weight decay. I didn't scale epsilon, which hopefully shouldn't make too much of a difference. Later I also tried scaling batch norm momentum.

For scaling the parameters, the linear scaling factor would be CURRENT_BATCH_SIZE / ORIGINAL_BATCH_SIZE. Learning rate and weight decay would be multiplied directly by it; for momentum, the time constant (i.e. -1 / ln(mom)) would be divided by it instead (giving a higher momentum for a lower batch size). I also tried scaling factors offset by 1/4, 1/2, 2 and 4 to see whether the scaling might be non-linear.
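As a concrete sketch of these rules in Python (the baseline values below are placeholders, not the ones used in the actual runs):

    import math

    def scale_hyperparams(batch_size, base_batch_size=64,
                          base_lr=1e-2, base_wd=1e-2,
                          base_mom=0.9, base_sqr_mom=0.99):
        # Linear scaling factor relative to the baseline batch size.
        factor = batch_size / base_batch_size

        def scale_mom(mom):
            # Divide the EMA time constant (-1 / ln(mom)) by the factor,
            # which works out to mom ** factor - a higher momentum for a
            # smaller batch size.
            tau = -1.0 / math.log(mom)
            return math.exp(-1.0 / (tau / factor))

        return {
            "lr": base_lr * factor,   # multiplied directly
            "wd": base_wd * factor,   # multiplied directly
            "mom": scale_mom(base_mom),
            "sqr_mom": scale_mom(base_sqr_mom),
        }

    # Smaller batch: lower lr and wd, higher momentum.
    print(scale_hyperparams(batch_size=16))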

Results

Sadly, when I ran all the experiments things didn't go exactly as predicted. Interpreting the results was also slightly complicated by the baseline not being the optimal training strategy, even after accounting for random error.
  • The most important change turned out to be scaling the learning rate, though at a slightly slower rate than the batch size (see the sketch after this list).
  • Increasing weight decay was useful when the batch size was large, but decreasing it for a small batch size wasn't helpful.
  • Scaling squared momentum didn't seem to have any real effect; scaling momentum seemed to make things slightly worse, but that is somewhat within error.
  • Scaling batch norm momentum made things slightly worse.
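As an illustration of the first point, one simple way to scale the learning rate more slowly than the batch size is to raise the linear factor to a power below one; the exponent here is an arbitrary placeholder, not a value fitted to the results:

    def scale_lr(batch_size, base_batch_size=64, base_lr=1e-2, power=0.75):
        # power=1.0 would be exact linear scaling; a power below 1 scales the
        # learning rate at a slightly slower rate than the batch size.
        # The value 0.75 is an illustrative placeholder, not a fitted exponent.
        return base_lr * (batch_size / base_batch_size) ** power

    for bs in (16, 32, 64, 128):
        print(bs, round(scale_lr(bs), 5))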


Feel free to have a look at the full raw results.