The document discusses modeling stochastic gradient descent (SGD) using stochastic differential equations (SDEs). It outlines SGD, random walks, Wiener processes, and SDEs. It then covers continuous-time SGD and controlled SGD, modeling SGD as an SDE. It provides an example of modeling quadratic loss functions with SGD as an SDE. Finally, it discusses the effects of learning rate and batch size on generalization when modeling SGD as an SDE.
6. Stochastic Gradient Descent (SGD)
• Convergence of SGD
• Assume that the loss function is convex:
$E[L(x, \bar{w}) - L(x, w^*)] \le O(1/\sqrt{T})$
T : number of steps
$\bar{w}$ : w after T steps
$w^*$ : w at the minimum of L
[Figure: the SGD iterate $\bar{w}$ and the minimum $w^*$; the distance between them is guaranteed to be small.]
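As an illustration of the rate, a minimal sketch: running SGD with step size proportional to $1/\sqrt{t}$ on a toy convex quadratic and averaging the iterates makes the loss gap shrink roughly like $1/\sqrt{T}$. The loss, noise model, and step-size constant below are assumptions for illustration, not from the slides.

```python
import numpy as np

# Minimal sketch: SGD with an averaged iterate w_bar on a convex quadratic,
# illustrating that E[L(w_bar) - L(w*)] shrinks on the order of 1/sqrt(T).
rng = np.random.default_rng(0)
w_star = 3.0
L = lambda w: 0.5 * (w - w_star) ** 2               # true loss, minimum at w*
grad_noisy = lambda w: (w - w_star) + rng.normal()  # stochastic gradient

for T in [100, 1000, 10000]:
    w, w_sum = 0.0, 0.0
    for t in range(1, T + 1):
        w -= (1.0 / np.sqrt(t)) * grad_noisy(w)     # step size ~ 1/sqrt(t)
        w_sum += w
    w_bar = w_sum / T                               # averaged iterate
    print(T, L(w_bar) - L(w_star))                  # gap shrinks roughly like 1/sqrt(T)
```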
10. Random Walk
Position of the particle at time t is a random variable $X_t$. The particle starts at x = 0 at t = 0; at each time step $\Delta t$ it moves by $\pm\Delta x$ with probability 1/2 each:
$P(X_{\Delta t} = \Delta x) = \frac{1}{2}, \quad P(X_{\Delta t} = -\Delta x) = \frac{1}{2}$
11. Random Walk
Distributions of the position after successive steps:
t = Δt : $X_{\Delta t} \in \{-\Delta x, +\Delta x\}$ with probabilities 1/2, 1/2
t = 2Δt : $X_{2\Delta t} \in \{-2\Delta x, 0, +2\Delta x\}$ with probabilities 1/4, 1/2, 1/4
t = 3Δt : $X_{3\Delta t} \in \{-3\Delta x, -\Delta x, +\Delta x, +3\Delta x\}$ with probabilities 1/8, 3/8, 3/8, 1/8
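A minimal simulation sketch checking the t = 3Δt row above (the sample count is an arbitrary choice):

```python
import numpy as np
from collections import Counter

# Minimal sketch: simulate the +-dx random walk and check that after
# 3 steps the positions -3, -1, +1, +3 (in units of dx) occur with
# probabilities 1/8, 3/8, 3/8, 1/8.
rng = np.random.default_rng(0)
steps = rng.choice([-1, 1], size=(100_000, 3))  # each step is +-dx w.p. 1/2
X3 = steps.sum(axis=1)                          # position at t = 3*dt
counts = Counter(X3)
for pos in (-3, -1, 1, 3):
    print(pos, counts[pos] / 100_000)           # ~0.125, 0.375, 0.375, 0.125
```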
13. Diffusion
Probability density function of $X_t$ : $p(x, t)$
$p(x, t = 0) = \mathcal{N}(0, 0)$ (a point mass at the origin), $p(x, t = T) = \mathcal{N}(0, DT)$
Diffusion equation:
$\frac{\partial p(x, t)}{\partial t} = \frac{D}{2}\frac{\partial^2 p(x, t)}{\partial x^2}$
14. Wiener process
A stochastic process W(·) is called a Wiener process if:
(1) W(0) = 0 almost surely,
(2) $W(t) - W(s) \sim \mathcal{N}(0, t - s)$ for all $t \ge s \ge 0$,
(3) $W(t_1),\ W(t_2) - W(t_1),\ \ldots,\ W(t_n) - W(t_{n-1})$ are independent random variables for all $t_n > t_{n-1} > \cdots > t_2 > t_1 > 0$.
$W(t) = X_{n\Delta t}$ is a Wiener process in the limit $t = n\Delta t$, $n \to \infty$, $\Delta t \to 0$ (with the diffusive scaling $\Delta x^2 = D\,\Delta t$).
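A minimal sketch of this limit: with the scaling $\Delta x = \sqrt{\Delta t}$ (i.e. D = 1), the scaled random walk at $t = n\Delta t$ has mean 0 and variance t, matching $W(t) \sim \mathcal{N}(0, t)$. The sample count and Δt below are arbitrary choices.

```python
import numpy as np

# Minimal sketch: approximate W(t) by a scaled random walk with
# dx = sqrt(dt), so Var[X_{n*dt}] = n*dt = t, matching W(t) ~ N(0, t).
rng = np.random.default_rng(0)
dt = 1e-3
n = 1000                                    # t = n*dt = 1.0
walks = rng.choice([-1.0, 1.0], size=(50_000, n)) * np.sqrt(dt)
W1 = walks.sum(axis=1)                      # samples of W(1)
print(W1.mean(), W1.var())                  # ~0 and ~1, i.e. N(0, t) with t = 1
```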
15. Wiener process
• Random Walk increments: $X_{3\Delta t} - X_{2\Delta t} = X_{\Delta t}$ in distribution, and the increment is independent of $X_{2\Delta t}$.
[Figure: random-walk branching between t = 2Δt and t = 3Δt, illustrating the independent-increment property.]
16. Outlines
• Stochastic Gradient Descent (SGD)
• Random Walk, Diffusion and Wiener process
• Stochastic Differential Equation (SDE)
• Continuous-time SGD & Controlled SGD
• Effects of SGD on Generalization
18. Stochastic Differential Equation (SDE)
• Stochastic Differential Equation
$\frac{dx(t)}{dt} = b(x(t)) + B(x(t))\frac{dW(t)}{dt}$, where $t > 0$ and $W(t)$ is a Wiener process, with initial condition $x(0) = x_0$.
$b(x(t))$ : deterministic part
$B(x(t))\,\frac{dW(t)}{dt}$ : stochastic part
[Figure: trajectory samples of x starting from $x_0$.]
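The SDE above can be simulated by discretizing time. Below is a minimal Euler–Maruyama sketch, assuming illustrative choices $b(x) = -x$ and constant $B(x) = 0.5$ (both hypothetical, not from the slides):

```python
import numpy as np

# Minimal sketch: Euler-Maruyama discretization of
# dx = b(x) dt + B(x) dW, with illustrative drift/diffusion choices.
rng = np.random.default_rng(0)
b = lambda x: -x           # deterministic part (assumed for illustration)
B = lambda x: 0.5          # stochastic part (assumed constant here)

dt, n_steps, x = 1e-3, 5000, 1.0      # x(0) = x0 = 1.0
for _ in range(n_steps):
    dW = rng.normal(0.0, np.sqrt(dt))  # Wiener increment ~ N(0, dt)
    x += b(x) * dt + B(x) * dW
print(x)   # one trajectory sample of x(t) at t = n_steps * dt
```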
23. Continuous-time SGD & Controlled SGD
• Notation Conventions:
Gradient Descent : $x_{k+1} = x_k - \eta\nabla f(x_k)$
Stochastic Gradient Descent : $x_{k+1} = x_k - \eta\nabla f_{\gamma_k}(x_k)$
f : loss function
$x_k$ : weights at step k
$\gamma_k$ : index of the training sample at step k (assume batch size is 1)
$f_i$ : loss function calculated on batch i, where $f(x) = (1/n)\sum_{i=1}^{n} f_i(x)$
24. Continuous-time SGD & Controlled SGD
$x_{k+1} - x_k = -\eta\nabla f_{\gamma_k}(x_k)$
$x_{k+1} - x_k = -\eta\nabla f(x_k) + \sqrt{\eta}\,V_k$
$V_k = \sqrt{\eta}\,(\nabla f(x_k) - \nabla f_{\gamma_k}(x_k))$
mean of $V_k$ : 0
covariance of $V_k$ : $\eta\Sigma(x_k)$, where $\Sigma(x_k) = (1/n)\sum_{i=1}^{n}(\nabla f(x_k) - \nabla f_i(x_k))(\nabla f(x_k) - \nabla f_i(x_k))^T$
Here $-\eta\nabla f(x_k)$ is the deterministic part and $\sqrt{\eta}\,V_k$ is the stochastic part.
[Figure: the deterministic part drives the iterate toward the minimum; the stochastic part makes it fluctuate around it.]
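A quick numerical check of these claims, as a sketch: for hypothetical per-sample quadratic losses $f_i(x) = \frac{1}{2}(x - c_i)^2$, the code below computes $\Sigma(x)$ directly and verifies that $V_k$ has mean 0 and covariance $\eta\Sigma(x)$ (all constants are illustrative assumptions):

```python
import numpy as np

# Minimal sketch: for per-sample losses f_i(x) = 0.5*(x - c_i)^2 we have
# grad f_i(x) = x - c_i, and Sigma(x) = (1/n) sum_i (grad f - grad f_i)^2
# in one dimension.
rng = np.random.default_rng(0)
c = rng.normal(size=100)                    # hypothetical per-sample targets
x = 2.0
grads = x - c                               # grad f_i(x) for each sample i
full_grad = grads.mean()                    # grad f(x)
Sigma = np.mean((full_grad - grads) ** 2)   # gradient-noise covariance (1-d)

eta = 0.1
# V_k = sqrt(eta) * (grad f(x_k) - grad f_{gamma_k}(x_k)) has mean 0 and
# covariance eta * Sigma(x_k):
V = np.sqrt(eta) * (full_grad - grads)
print(V.mean(), V.var(), eta * Sigma)       # ~0, and the two variances agree
```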
27. Continuous-time SGD & Controlled SGD
$X_t \sim \mathcal{N}\left(x_0 e^{-2(1+\eta)t},\ \frac{\eta}{1+\eta}\left(1 - e^{-4(1+\eta)t}\right)\right)$
$E[X_t] = x_0$ when $t = 0$, and $E[X_t] \to 0$ when $t \to \infty$
$Var[X_t] = 0$ when $t = 0$, and $Var[X_t] \to \frac{\eta}{1+\eta}$ when $t \to \infty$
The descent phase ends and the fluctuations phase begins at the time $t^*$ where $E[X_{t^*}] = \sqrt{Var[X_{t^*}]}$.
[Figure: sample trajectories of $X_t$ over t; after the descent phase, trajectories fluctuate at the level $\sqrt{\eta/(1+\eta)}$.]
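Since the closed form above fully determines $X_t$, a small sketch can evaluate the two limits and locate the transition time $t^*$ numerically (the values of η and $x_0$ are assumptions for illustration):

```python
import numpy as np

# Minimal sketch: evaluate the closed-form mean/variance of X_t above
# and check the two limits, then find the descent -> fluctuations
# transition time t* where E[X_t] = sqrt(Var[X_t]).
eta, x0 = 0.1, 1.0
mean = lambda t: x0 * np.exp(-2 * (1 + eta) * t)
var = lambda t: eta / (1 + eta) * (1 - np.exp(-4 * (1 + eta) * t))

for t in [0.0, 0.5, 1.0, 5.0]:
    print(t, mean(t), var(t))        # mean -> 0, var -> eta/(1+eta) ~ 0.0909

lo, hi = 0.0, 5.0                    # simple bisection for t*
for _ in range(60):
    mid = 0.5 * (lo + hi)
    lo, hi = (mid, hi) if mean(mid) > np.sqrt(var(mid)) else (lo, mid)
print("t* ~", lo)
```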
31. Continuous-time SGD & Controlled SGD
• Optimal control policy
$u^*_t = \begin{cases} 1, & \text{if } a \le 0 \text{ or } t \le t^* \quad (t \le t^*\text{: descent phase}) \\ \frac{1}{1 + a(t - t^*)}, & \text{if } a > 0 \text{ and } t > t^* \quad (t > t^*\text{: fluctuations phase}) \end{cases}$
$f(x) = \frac{1}{2}a(x - b)^2$, assuming the covariance of $f'_i$ is $\eta\Sigma(x)$
[Figure: descent phase for $t \le t^*$, fluctuations phase for $t > t^*$, shown for the two cases $a \le 0$ and $a > 0$.]
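A direct transcription of this policy as code; a minimal sketch, with a, $t^*$, and η chosen only for illustration:

```python
# Minimal sketch of the optimal control policy u*_t above: keep the full
# learning rate during the descent phase, then decay it as 1/(1+a(t-t*))
# during the fluctuations phase (only when the curvature a is positive).
def u_star(t, a, t_star):
    if a <= 0 or t <= t_star:
        return 1.0                          # descent phase: full learning rate
    return 1.0 / (1.0 + a * (t - t_star))   # fluctuations phase: anneal

# Example: effective learning rate eta * u*_t over time (a > 0 case)
eta, a, t_star = 0.1, 2.0, 1.0
for t in [0.0, 0.5, 1.0, 2.0, 4.0]:
    print(t, eta * u_star(t, a, t_star))
```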
32. Continuous-time SGD & Controlled SGD
• General Objective Function
$f(x)$ and $f_i(x)$ are not necessarily quadratic, and $x \in \mathbb{R}^d$.
Assume $f(x) \approx \frac{1}{2}\sum_{i=1}^{d} a^{(i)}(x^{(i)} - b^{(i)})^2$ holds locally in x, and
$\Sigma \approx \mathrm{diag}\{\Sigma^{(1)}, \ldots, \Sigma^{(d)}\}$ where each $\Sigma^{(i)}$ is locally constant
(each dimension is independent).
33. Continuous-time SGD & Controlled SGD
• Controlled SGD Algorithms
At each step k, estimate $a_{k,(i)}, b_{k,(i)}$ for the local model $\frac{1}{2}a_{k,(i)}(x_{k,(i)} - b_{k,(i)})^2$.
Since $\nabla f^{(i)} \approx a^{(i)}(x^{(i)} - b^{(i)})$, we use linear regression to estimate $a_{k,(i)}, b_{k,(i)}$:
$a_{k,(i)} = \frac{\overline{gx}_{k,(i)} - \bar{g}_{k,(i)}\,\bar{x}_{k,(i)}}{\overline{x^2}_{k,(i)} - \bar{x}_{k,(i)}^2}$, and $b_{k,(i)} = \bar{x}_{k,(i)} - \frac{\bar{g}_{k,(i)}}{a_{k,(i)}}$
where $g_{k,(i)} = \nabla f_{\gamma_k}(x_k)^{(i)}$, and the bars denote exponential moving averages, e.g. $\bar{g}_{k+1,(i)} = \beta_{k,(i)}\bar{g}_{k,(i)} + (1 - \beta_{k,(i)})\,g_{k,(i)}$.
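A one-dimensional sketch of this estimator, assuming a constant EMA weight β and a noisy gradient from the quadratic model above (all constants are illustrative):

```python
import numpy as np

# Minimal 1-d sketch of the estimator above: maintain exponential moving
# averages of x, x^2, g, and g*x, then read off a and b from the
# regression formulas. beta is held constant here (the slide allows a
# per-step beta_k).
rng = np.random.default_rng(0)
a_true, b_true, eta, beta = 2.0, 1.0, 0.2, 0.999

x = 5.0
x_bar = x2_bar = g_bar = gx_bar = 0.0
for k in range(20_000):
    g = a_true * (x - b_true) + rng.normal(0.0, 0.5)  # noisy gradient g_k
    x_bar  = beta * x_bar  + (1 - beta) * x           # EMA of x
    x2_bar = beta * x2_bar + (1 - beta) * x * x       # EMA of x^2
    g_bar  = beta * g_bar  + (1 - beta) * g           # EMA of g
    gx_bar = beta * gx_bar + (1 - beta) * g * x       # EMA of g*x
    x -= eta * g                                      # plain SGD step

a_hat = (gx_bar - g_bar * x_bar) / (x2_bar - x_bar ** 2)
b_hat = x_bar - g_bar / a_hat
print(a_hat, b_hat)   # should be close to a_true = 2.0, b_true = 1.0
```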
45. Effects of SGD on Generalization
• Theoretical Explanation
$d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\,R(\theta)\,dW(t)$,
where $R(\theta)R(\theta)^T = C(\theta)$, $C(\theta)$ is the covariance of $g(\theta)$, $\eta$ is the learning rate, and S is the batch size.
Change of variables:
$dz = -\Lambda z\,dt + \sqrt{\frac{\eta}{S}}\,\sqrt{\Lambda}\,dW(t)$
z : new variable, where $z = V^T(\theta - \theta^*)$
$\theta^*$ : the parameters at the minimum
V : orthogonal matrix of the eigendecomposition $H = V\Lambda V^T$
H : the Hessian of $L(\theta)$
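For each eigenvalue λ of Λ, the z-equation is an Ornstein–Uhlenbeck process $dz = -\lambda z\,dt + \sqrt{\eta\lambda/S}\,dW$ whose stationary variance is $(\eta\lambda/S)/(2\lambda) = \eta/(2S)$, independent of λ: the noise level around the minimum is set by the ratio of learning rate to batch size. A minimal simulation sketch of one eigendirection (all constants are assumptions):

```python
import numpy as np

# Minimal sketch: simulate dz = -lam*z dt + sqrt(eta/S)*sqrt(lam) dW for one
# eigendirection and check that the stationary variance is eta/(2*S),
# regardless of the eigenvalue lam.
rng = np.random.default_rng(0)
eta, S, lam = 0.1, 32, 5.0
dt, n_steps = 1e-3, 100_000

z, samples = 1.0, []
for k in range(n_steps):
    dW = rng.normal(0.0, np.sqrt(dt))
    z += -lam * z * dt + np.sqrt(eta / S) * np.sqrt(lam) * dW
    if k > n_steps // 2:                 # discard transient, keep stationary part
        samples.append(z)
print(np.var(samples), eta / (2 * S))    # both ~ 0.0016
```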