4. • Amortized inference
• An inference network f(xn; ψ) maps new data xn directly to the variational parameters for that data point
• Similar idea used in the Helmholtz machine (Dayan et al., 1995)
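The idea can be sketched numerically. A minimal sketch with numpy, assuming a two-layer tanh network; all sizes and the architecture below are illustrative choices, not the model from the text:

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions): data dim D, latent dim K, hidden width H.
D, K, H = 4, 2, 8

# psi: parameters of the inference network f(x; psi), shared by ALL data points.
psi = {
    "W1": rng.normal(scale=0.1, size=(D, H)), "b1": np.zeros(H),
    "Wm": rng.normal(scale=0.1, size=(H, K)), "bm": np.zeros(K),
    "Wv": rng.normal(scale=0.1, size=(H, K)), "bv": np.zeros(K),
}

def f(x, psi):
    """Map a data point x to the variational parameters (mean, log-variance)
    of q(z | x) in one forward pass -- no per-point optimization loop."""
    h = np.tanh(x @ psi["W1"] + psi["b1"])
    return h @ psi["Wm"] + psi["bm"], h @ psi["Wv"] + psi["bv"]

x_new = rng.normal(size=D)   # a new data point: its q(z | x) is available immediately
mu, log_var = f(x_new, psi)
```

The amortization is in the shared ψ: inference for a new point costs one forward pass instead of a fresh optimization of per-point variational parameters.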
7. • Gradients of parameters
∇ξ ℒ𝒮(ψ, ξ) = (N/M) ∑n∈𝒮 ∇ξ E[ln p(xn ∣ zn, W)] + ∇ξ E[ln p(W)] − ∇ξ E[ln q(W)]   (6.12)
∇ψ ℒ𝒮(ψ, ξ) = (N/M) ∑n∈𝒮 {∇ψ E[ln p(xn ∣ zn, W)] + ∇ψ E[ln p(zn)] − ∇ψ E[ln q(zn)]}   (6.13)
ξ : variational parameters of q(W; ξ)
ψ : inference-network parameters of f(xn; ψ)
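The N/M factor in (6.12)–(6.13) rescales a minibatch sum so that it is an unbiased estimate of the full-data sum. A small numerical check, with synthetic scalars standing in for the per-datapoint gradient terms (the values are assumptions; only the rescaling logic is the point):

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy scalars standing in for the per-datapoint terms summed in Eq. (6.12).
N, M = 100, 10
per_point = rng.normal(size=N)    # one term per data point n
full_sum = per_point.sum()        # the full-data sum over n = 1..N

# (N/M) * (sum over a random minibatch S of size M) is unbiased for full_sum.
reps = 20000
estimates = np.empty(reps)
for r in range(reps):
    S = rng.choice(N, size=M, replace=False)
    estimates[r] = N / M * per_point[S].sum()

# The average over many minibatches is close to the full-data sum.
```

Without the N/M factor, the data-dependent term would be systematically too small relative to the prior and entropy terms, which are not subsampled.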
8. 6.1.2 Semi-supervised models
Labelled data 𝒟𝒜 = {X𝒜, Y𝒜}
Unlabelled data 𝒟𝒰 = X𝒰
6.1.2.1 M1 model
1. Train encoder and decoder with {X𝒜, X𝒰}
2. Train supervised model with {Z𝒜, Y𝒜}, where Z𝒜 is encoded from X𝒜 with the model of step 1
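The two-stage M1 recipe can be sketched with simple stand-ins: a linear PCA projection plays the role of the VAE encoder, and a nearest-centroid rule plays the role of the supervised model. Both stand-ins and all data are illustrative assumptions; only the pipeline (unsupervised encoder on all X, then a classifier on encoded labelled data) comes from the text:

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy data (assumed): two labelled Gaussian clusters plus unlabelled noise.
X_A = rng.normal(size=(50, 6)) + np.repeat(np.eye(6)[:2] * 3, 25, axis=0)
Y_A = np.repeat([0, 1], 25)
X_U = rng.normal(size=(200, 6))

# Step 1: train the encoder on ALL inputs {X_A, X_U}. Stand-in for a VAE
# encoder: linear projection onto the top two principal components.
X_all = np.vstack([X_A, X_U])
mu = X_all.mean(axis=0)
_, _, Vt = np.linalg.svd(X_all - mu, full_matrices=False)
encode = lambda X: (X - mu) @ Vt[:2].T          # z = encoder(x)

# Step 2: train a supervised model on {Z_A, Y_A}, where Z_A = encode(X_A).
Z_A = encode(X_A)
centroids = np.stack([Z_A[Y_A == c].mean(axis=0) for c in (0, 1)])
predict = lambda X: np.argmin(
    ((encode(X)[:, None, :] - centroids) ** 2).sum(-1), axis=1)

acc = (predict(X_A) == Y_A).mean()
```

The key design point survives the simplification: the encoder sees labelled and unlabelled inputs alike, while the classifier only ever sees encoded labelled pairs.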
9. 6.1.2.2 M2 model
[Graphical model: the shared parameter W generates X𝒜 from (Y𝒜, Z𝒜) and X𝒰 from (Y𝒰, Z𝒰)]
• Generative process with shared parameter W (and shared priors p(Y) and p(Z)):
p(X𝒜, X𝒰, Y𝒜, Y𝒰, Z𝒜, Z𝒰, W) = p(X𝒜 | Y𝒜, Z𝒜, W) p(Y𝒜) p(Z𝒜) p(X𝒰 | Y𝒰, Z𝒰, W) p(Y𝒰) p(Z𝒰) p(W)   (6.14)
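Ancestral sampling from (6.14) makes the sharing concrete: the same W produces labelled and unlabelled data, and "unlabelled" just means y is discarded. The linear one-hot decoder and all sizes below are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(3)

K, C, D = 2, 3, 5                               # latent dim, classes, data dim (assumed)
W = rng.normal(scale=0.5, size=(C + K, D))      # shared decoder parameter W

def sample_x(n):
    """Ancestral sampling from the M2 generative process of Eq. (6.14)."""
    y = rng.integers(C, size=n)                 # y ~ p(Y): uniform categorical
    z = rng.normal(size=(n, K))                 # z ~ p(Z) = N(0, I)
    inputs = np.hstack([np.eye(C)[y], z])       # one-hot y concatenated with z
    x = inputs @ W + 0.1 * rng.normal(size=(n, D))   # x | y, z, W
    return x, y, z

X_A, Y_A, Z_A = sample_x(4)   # "labelled": keep the sampled y
X_U, _, _ = sample_x(6)       # "unlabelled": same process, y discarded
```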
10. • Approximate posterior
q(Z𝒜; X𝒜, Y𝒜, ψ) = ∏n∈𝒜 𝒩(zn | m(xn, yn; ψ), diagm(v(xn, yn; ψ)))   (6.15)
q(Z𝒰; X𝒰, ψ) = ∏n∈𝒰 𝒩(zn | m(xn; ψ), diagm(v(xn; ψ)))   (6.16)
q(Y𝒰; X𝒰, ψ) = ∏n∈𝒰 Cat(yn | π(xn; ψ))   (6.17)
m, v, π : inference networks parametrized with ψ
q(W; ξ) : Gaussian distribution parametrized with ξ
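A sketch of the three inference networks m, v, π sharing ψ, for the unlabelled case of (6.16)–(6.17); for labelled data, m and v would also take yn as in (6.15). The shared tanh trunk, exponential link for v, and all sizes are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(4)
D, K, C, H = 5, 2, 3, 8    # data dim, latent dim, classes, hidden width (assumed)

# psi: one parameter set serving all three heads m, v, pi.
psi = {"W": rng.normal(scale=0.1, size=(D, H)),
       "Wm": rng.normal(scale=0.1, size=(H, K)),
       "Wv": rng.normal(scale=0.1, size=(H, K)),
       "Wp": rng.normal(scale=0.1, size=(H, C))}

def hidden(x):                       # shared trunk of the inference networks
    return np.tanh(x @ psi["W"])

def m(x):
    return hidden(x) @ psi["Wm"]                 # mean of q(z | x)

def v(x):
    return np.exp(hidden(x) @ psi["Wv"])         # variance: exp keeps it positive

def pi(x):                                       # class probabilities of q(y | x)
    logits = hidden(x) @ psi["Wp"]
    e = np.exp(logits - logits.max())            # stable softmax
    return e / e.sum()

x = rng.normal(size=D)
```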
11. • KL-divergence
DKL[q(Y𝒰, Z𝒜, Z𝒰, W; X𝒜, Y𝒜, X𝒰, ξ, ψ) ∥ p(Y𝒰, Z𝒜, Z𝒰, W ∣ X𝒜, X𝒰, Y𝒜)] = −ℱ(ξ, ψ) + const.   (6.18)
ℱ(ξ, ψ) = ℒ𝒜(X𝒜, Y𝒜; ξ, ψ) + ℒ𝒰(X𝒰; ξ, ψ) − DKL[q(W; ξ) ∥ p(W)]   (6.19)
ℒ𝒜(X𝒜, Y𝒜; ξ, ψ) = E[ln p(X𝒜 | Y𝒜, Z𝒜, W)] + E[ln p(Z𝒜)] − E[ln q(Z𝒜; X𝒜, Y𝒜, ψ)]   (6.20)
ℒ𝒰(X𝒰; ξ, ψ) = E[ln p(X𝒰 | Y𝒰, Z𝒰, W)] + E[ln p(Y𝒰)] + E[ln p(Z𝒰)] − E[ln q(Y𝒰; X𝒰, ψ)] − E[ln q(Z𝒰; X𝒰, ψ)]   (6.21)
• Maximize ℱ(ξ, ψ) w.r.t. ξ and ψ
12. • Extension of the objective function to use labelled data with a classification likelihood
ℱβ(ξ, ψ) = ℱ(ξ, ψ) + β ln q(Y𝒜; X𝒜, ψ)   (6.22)
β : weight of the classification likelihood
13. 6.1.3 Applications and extensions
6.1.3.1 Extension of models
• Incorporate recurrent network and attention (DRAW)
• Convolutional VAE
• Disentangled representation learning
• Multi-modal learning with shared latent representation
(e.g., images and texts)
Explanation of DRAW with a Python implementation:
https://jhui.github.io/2017/04/30/DRAW-Deep-recurrent-attentive-writer/
14. 6.1.3.2 Importance weighted AE
ℒT = E z(t)∼q(z(t); x) [ ln (1/T) ∑t=1..T p(x, z(t)) / q(z(t); x) ]
 ≤ ln E z(t)∼q(z(t); x) [ (1/T) ∑t=1..T p(x, z(t)) / q(z(t); x) ]
 = ln E z(t)∼q(z(t); x) [ (1/T) ∑t=1..T p(x | z(t)) p(z(t)) / q(z(t); x) ]
 = ln p(x)   (6.23)
• Equivalent to the ELBO when T = 1
• The larger T is, the tighter the bound (Appendix A in the paper):
ln p(x) ≥ ⋯ ≥ ℒt+1 ≥ ℒt ≥ ⋯ ≥ ℒ1 = ℒ   (6.24)
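The tightening in (6.24) can be checked numerically on a conjugate toy model where ln p(x) is available in closed form. The model, the choice of the prior as proposal, and the Monte Carlo setup below are all illustrative assumptions:

```python
import numpy as np
from math import log, pi

rng = np.random.default_rng(5)

# Conjugate toy model (assumed): z ~ N(0,1), x | z ~ N(z,1)  =>  p(x) = N(0, 2).
x = 1.5
log_px = -0.5 * (log(2 * pi * 2) + x ** 2 / 2)     # exact ln p(x)

def iwae_bound(T, reps=20000):
    """Monte Carlo estimate of L_T in Eq. (6.23), using the PRIOR p(z) as a
    deliberately crude proposal q, so each weight is w = p(x, z)/q(z) = p(x | z)."""
    z = rng.normal(size=(reps, T))                  # z ~ q = p(z)
    log_w = -0.5 * (log(2 * pi) + (x - z) ** 2)     # ln p(x | z)
    m = log_w.max(axis=1, keepdims=True)            # stable log of the average weight
    log_avg = m[:, 0] + np.log(np.exp(log_w - m).mean(axis=1))
    return log_avg.mean()

L1, L10, L100 = iwae_bound(1), iwae_bound(10), iwae_bound(100)
# Expect L1 <= L10 <= L100, all below ln p(x), tightening as T grows.
```

T = 1 reproduces the ordinary ELBO; averaging more importance weights inside the log pulls the bound up toward ln p(x), exactly as the chain in (6.24) states.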