Recently I had an idea: if you scale the hyperparameters correctly, there shouldn't be any real
difference between training with different batch sizes. The results should even be better
with a smaller batch size. An analogy would be smoothing and numerically integrating a function,
where we apply the smoothing because it is cheap, while sampling the function itself is expensive.
When we increase the sampling frequency from once every 32 steps to once every 4 steps, we can't
keep multiplying each sample value by 32 - and similarly, we should also scale down the
learning rate.
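To make the analogy concrete, here is a small numerical sketch (my own illustration, not part of the experiments; the signal and names are made up): each sample stands in for the interval it covers, so sampling 8x more often means weighting each sample 8x less.

```python
import numpy as np

rng = np.random.default_rng(0)
signal = rng.standard_normal(4096).cumsum()  # stand-in for an expensive-to-sample function

def integrate(sample_every, n_steps=4096):
    idx = np.arange(0, n_steps, sample_every)
    # weight each sample by the number of steps it represents
    return (signal[idx] * sample_every).sum()

coarse = integrate(sample_every=32)  # analogue of the large batch size
dense = integrate(sample_every=4)    # analogue of the small batch size
# Both estimates approximate the same integral; if the dense samples were still
# weighted by 32, the estimate would come out roughly 8x too large.
print(coarse, dense)
```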
Experimental setup
For all the experiments I have been using imagenette with the 20 epoch top result at the time
I started working on the project (there is a different top result now).
At the beginning I experimented with scaling all the optimizer hyperparameters: learning rate,
momentum, squared momentum and weight decay. I didn't scale epsilon, which hopefully shouldn't
make too much of a difference. Later I also tried scaling the batch norm momentum.
For scaling the parameters, the linear scaling factor would be CURRENT_BATCH_SIZE / ORIGINAL_BATCH_SIZE.
Learning rate and weight decay would be multiplied directly by it, while for momentum the time
constant (i.e. -1 / ln(mom)) would be divided by it instead (giving a higher momentum for a lower
batch size). I also tried scaling factors offset by 1/4, 1/2, 2 and 4 to see if the scaling might
be non-linear.
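As a sketch of the scaling rule described above (the helper below is my own, hypothetical code, not what was used in the experiments):

```python
import math

def scale_hyperparams(lr, wd, mom, orig_bs, new_bs):
    """Hypothetical helper sketching the linear scaling rule described above."""
    factor = new_bs / orig_bs
    scaled_lr = lr * factor   # learning rate scales directly with the factor
    scaled_wd = wd * factor   # weight decay scales the same way
    # Momentum is scaled through its time constant tau = -1 / ln(mom): the time
    # constant (in steps) is divided by the factor, so a smaller batch size
    # (factor < 1) gives a longer time constant and hence a higher momentum.
    tau = -1.0 / math.log(mom)
    scaled_mom = math.exp(-factor / tau)
    return scaled_lr, scaled_wd, scaled_mom

# e.g. going from batch size 64 down to 8 (factor = 1/8):
# lr and wd shrink 8x, momentum 0.9 becomes roughly 0.987
print(scale_hyperparams(lr=1e-2, wd=1e-2, mom=0.9, orig_bs=64, new_bs=8))
```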
Results
Sadly, when I ran all the experiments, things didn't go exactly as predicted. Interpreting the
results was also slightly complicated by the baseline not being the optimal training strategy,
even after accounting for random errors.
Scaling the learning rate turned out to be the most important, though it should be scaled
at a slightly slower rate than the batch size.
Increasing weight decay was useful when the batch size was large, but decreasing it for a small
batch size didn't help.
Scaling the squared momentum didn't seem to have any real effect; scaling momentum seemed
to make things slightly worse, but that is somewhat within error.
Scaling batch norm momentum also made things slightly worse.