[D] Higher-order corrections to the SGD continuous-time limit
I’ve seen a lot of theoretical studies of SGD that consider it in the limit as the step size goes to zero, which turns it into a stochastic differential equation of the form dx/dt = alpha*(-grad(loss)(x) + noise). Many useful quantities, e.g. stationary distributions, are a lot easier to compute in this form.
However, one of the things that gets lost in this formalism is the intrinsic scales of the problem. In the continuous limit, rescaling the time variable t (or more generally, performing an arbitrary coordinate transformation) leaves the trajectory invariant as a curve, because the differential equation is formulated in a covariant fashion. This gives misleading results if you want to analyze something like convergence rates. In the continuous formulation, you can just rescale your time parameter to ‘speed up’ your dynamics (which is equivalent to increasing alpha), whereas you obviously can’t do this in the discrete formulation, because if you increase alpha arbitrarily you overshoot and get bad convergence.
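To make the overshoot point concrete, here is a minimal sketch on a toy problem that is my own assumption and not part of the question above: a 1-D quadratic loss 0.5*lam*x^2, with hypothetical helper names `gd_iterates` and `gradient_flow`. The continuous flow decays for any alpha, while the discrete iteration blows up once the step exceeds 2/lam.

```python
import math

def gd_iterates(x0, lam, step, n_steps):
    """Discrete gradient descent on loss(x) = 0.5 * lam * x**2."""
    xs = [x0]
    for _ in range(n_steps):
        # grad(loss)(x) = lam * x
        xs.append(xs[-1] - step * lam * xs[-1])
    return xs

def gradient_flow(x0, lam, alpha, t):
    """Exact solution of x'(t) = -alpha * lam * x, the continuous-time limit."""
    return x0 * math.exp(-alpha * lam * t)

lam, x0 = 1.0, 1.0

# Small step: the discrete iterates shrink towards the minimum at 0.
small = gd_iterates(x0, lam, step=0.1, n_steps=50)
# Step larger than 2 / lam: each update multiplies x by -1.5, so it diverges.
large = gd_iterates(x0, lam, step=2.5, n_steps=50)
# The continuous flow with the same large alpha just converges faster.
flow_fast = gradient_flow(x0, lam, alpha=2.5, t=50.0)

print(abs(small[-1]), abs(large[-1]), flow_fast)
```

Nothing in the continuous equation signals that alpha = 2.5 is a bad choice here; that information only lives in the discrete update.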
The first thing that came to mind when I started thinking about this was that you could amend your differential equation to include higher-order correction terms. Specifically, if we have a differential equation of the form x'(t) = f(x), we can Taylor expand to get x(t + delta) = x(t) + f(x)*delta + 0.5*Df(x)*f(x)*delta^2 + O(delta^3). This tells us that the difference between the continuous solution x(t + delta) and the discrete update x(t) + f(x)*delta after a time delta will be roughly 0.5*Df(x)*f(x)*delta^2. In order to get a more accurate model for the discrete-time process x(t + delta) = x(t) + f(x)*delta, we can introduce a correction term into our differential equation that cancels this error to leading order: x'(t) = f(x) - 0.5*Df(x)*f(x)*delta. When f is -alpha*grad(loss), we have Df = -alpha*Hessian(loss), so this becomes x'(t) = -alpha*grad(loss)(x) - 0.5*alpha^2*Hessian(loss)(x)*grad(loss)(x)*delta. This correction term breaks the rescaling freedom: relative to the gradient term it has size 0.5*alpha*delta*Hessian(loss), and alpha*delta is the actual discrete step size, which stays fixed when t is rescaled (alpha -> alpha/c, delta -> c*delta). So you can no longer speed up the dynamics by increasing alpha at fixed delta without the correction term growing along with it. It seems to me like this is a natural way to model the breakdown of covariance in the discrete dynamical system in the continuous setting and to study why certain timescales/learning rates are preferred.
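As a quick numerical sanity check of the corrected equation, here is a rough sketch under my own assumptions: a 1-D quadratic loss with lam = 1, a hand-rolled RK4 integrator, and hypothetical helper names (`plain_field`, `modified_field`, `rk4_integrate`, `discrete_iterates`). It compares the discrete iterates against the plain flow and against the flow with the -0.5*Df(x)*f(x)*delta term added.

```python
def grad(x, lam=1.0):
    """Gradient of the toy loss 0.5 * lam * x**2."""
    return lam * x

def hess(x, lam=1.0):
    """Hessian of the toy loss (a constant in 1-D)."""
    return lam

def plain_field(x, alpha, delta):
    """Right-hand side of the uncorrected flow x' = -alpha * grad(loss)(x)."""
    return -alpha * grad(x)

def modified_field(x, alpha, delta):
    """Flow with the first-order correction -0.5 * Df(x) * f(x) * delta."""
    return -alpha * grad(x) - 0.5 * alpha**2 * delta * hess(x) * grad(x)

def rk4_integrate(field, x0, alpha, delta, total_time, substeps=2000):
    """Integrate x' = field(x) with classical RK4 on a fine grid."""
    h = total_time / substeps
    x = x0
    for _ in range(substeps):
        k1 = field(x, alpha, delta)
        k2 = field(x + 0.5 * h * k1, alpha, delta)
        k3 = field(x + 0.5 * h * k2, alpha, delta)
        k4 = field(x + h * k3, alpha, delta)
        x += h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
    return x

def discrete_iterates(x0, alpha, delta, n_steps):
    """The discrete process x(t + delta) = x(t) + f(x) * delta."""
    x = x0
    for _ in range(n_steps):
        x = x - alpha * delta * grad(x)
    return x

alpha, delta, n_steps, x0 = 1.0, 0.1, 20, 1.0
x_disc = discrete_iterates(x0, alpha, delta, n_steps)
x_plain = rk4_integrate(plain_field, x0, alpha, delta, n_steps * delta)
x_mod = rk4_integrate(modified_field, x0, alpha, delta, n_steps * delta)

err_plain = abs(x_plain - x_disc)
err_mod = abs(x_mod - x_disc)
print(x_disc, x_plain, x_mod, err_plain, err_mod)
```

On this toy problem the corrected flow tracks the discrete iterates more than ten times more closely than the plain flow does, which is what you would expect from cancelling the delta^2 term in the per-step error.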
tl;dr: Does anyone know if this version of the continuous-time limit has been studied before? If so, can someone point me towards some references that I can read up on?
submitted by /u/glockenspielcello