4. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
5.
6.
Learning to Learn
● What is meta learning / learning to learn?
○ Go beyond the setting where training and test data come from the same distribution.
○ The task changes between training and testing, so the model has to “learn to learn”.
● Datasets
Lake et al., 2013, 2015
7.
Learning to Learn
[Figure: losses]
8.
9.
10.
11.
T: meta-training task; T′: meta-testing task
[Figure: CIFAR-10 example images of the ten classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck), shown for both T and T′]
Image source: https://www.cs.toronto.edu/~kriz/cifar.html
1. Sample a label set L from T.
2. Sample a few images from L as the support set S.
3. Sample a few images from L as the batch B.
4. Optimize on the batch, then go to step 1.
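The four sampling steps above can be sketched in Python. This is a minimal illustration, not the deck's code: it assumes the training data is a list of (image, label) pairs, draws the support set S and the batch B from disjoint examples, and leaves step 4 (the optimizer update) to whatever model is being trained. The function name `sample_episode` and its parameters are hypothetical.

```python
import random
from collections import defaultdict


def sample_episode(dataset, n_way, k_shot, n_query, rng=random):
    """Sample one episode (steps 1-3 above) from (example, label) pairs."""
    by_label = defaultdict(list)
    for x, y in dataset:
        by_label[y].append(x)
    # Step 1: sample a label set L of n_way classes that have enough examples.
    eligible = sorted(y for y, xs in by_label.items() if len(xs) >= k_shot + n_query)
    label_set = rng.sample(eligible, n_way)
    support, batch = [], []
    for y in label_set:
        # Draw k_shot + n_query distinct examples of class y without replacement.
        xs = rng.sample(by_label[y], k_shot + n_query)
        support += [(x, y) for x in xs[:k_shot]]  # Step 2: support set S
        batch += [(x, y) for x in xs[k_shot:]]    # Step 3: batch B
    return label_set, support, batch


classes = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]
# Toy stand-in for images: string ids, 20 per class.
dataset = [(f"{c}_{i}", c) for c in classes for i in range(20)]
label_set, support, batch = sample_episode(dataset, n_way=3, k_shot=1, n_query=2,
                                           rng=random.Random(0))
```

Drawing the support and batch examples in a single `sample` call is what keeps them disjoint; a 3-way, 1-shot episode like the {automobile, cat, deer} example in the following slides corresponds to `n_way=3, k_shot=1`.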
12.
13.
L′: label set {automobile, cat, deer} (3 classes)
14.
[Figure: label set L′ = {automobile, cat, deer}, support set S′ with example images of each class, and a query image x̂]
15.
16.
17.
Trends: Learning to Learn / Meta Learning
Model Based
● Santoro et al. ’16
● Duan et al. ’17
● Wang et al. ’17
● Munkhdalai & Yu ’17
● Mishra et al. ’17
Optimization Based
● Schmidhuber ’87, ’92
● Bengio et al. ’90, ’92
● Hochreiter et al. ’01
● Li & Malik ’16
● Andrychowicz et al. ’16
● Ravi & Larochelle ’17
● Finn et al. ’17
Metric Based
● Koch ’15
● Vinyals et al. ’16
● Snell et al. ’17
● Shyam et al. ’17
18.
Model Based Meta Learning
Oriol Vinyals, NIPS 17
19.
Metric Based Meta Learning
Matching Networks, Vinyals et al., NIPS 2016
Oriol Vinyals, NIPS 17
20.
2.1 Notation
In few-shot classification we are given a small support set of N labeled examples $S = \{(x_1, y_1), \dots, (x_N, y_N)\}$ where each $x_i \in \mathbb{R}^D$ is the D-dimensional feature vector of an example and $y_i \in \{1, \dots, K\}$ is the corresponding label. $S_k$ denotes the set of examples labeled with class k.
2.2 Model
Prototypical Networks compute an M-dimensional representation $c_k \in \mathbb{R}^M$, or prototype, of each class through an embedding function $f_\phi : \mathbb{R}^D \to \mathbb{R}^M$ with learnable parameters $\phi$. Each prototype is the mean vector of the embedded support points belonging to its class:
$$c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i) \tag{1}$$
Given a distance function $d : \mathbb{R}^M \times \mathbb{R}^M \to [0, +\infty)$, Prototypical Networks produce a distribution over classes for a query point x based on a softmax over distances to the prototypes in the embedding space:
$$p_\phi(y = k \mid x) = \frac{\exp(-d(f_\phi(x), c_k))}{\sum_{k'} \exp(-d(f_\phi(x), c_{k'}))} \tag{2}$$
Learning proceeds by minimizing the negative log-probability $J(\phi) = -\log p_\phi(y = k \mid x)$ of the true class k via SGD. Training episodes are formed by randomly selecting a subset of classes from the training set, then choosing a subset of examples within each class to act as the support set and a …
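Equations (1) and (2) and the loss $J(\phi)$ fit in a few lines of NumPy. This sketch assumes the embeddings $f_\phi(x)$ have already been computed (the toy example simply uses 2-D points as embeddings) and takes d to be the squared Euclidean distance; the function names are illustrative only.

```python
import numpy as np


def prototypes(emb, labels, n_classes):
    """Eq. (1): c_k = mean of the embedded support points of class k."""
    return np.stack([emb[labels == k].mean(axis=0) for k in range(n_classes)])


def log_probs(query_emb, protos):
    """Eq. (2) with d = squared Euclidean distance, as a log-softmax over -d."""
    # dists[q, k] = ||f(x_q) - c_k||^2
    dists = ((query_emb[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)
    logits = -dists
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    return logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))


def episode_loss(query_emb, query_labels, protos):
    """J(phi) = -log p_phi(y = k | x), averaged over the query points."""
    lp = log_probs(query_emb, protos)
    return -lp[np.arange(len(query_labels)), query_labels].mean()


# Toy 2-way episode; embeddings are assumed to be already computed by f_phi.
support = np.array([[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]])
support_labels = np.array([0, 0, 1, 1])
protos = prototypes(support, support_labels, n_classes=2)
query = np.array([[0.5, 1.0], [3.5, 1.0]])
query_labels = np.array([0, 1])
preds = log_probs(query, protos).argmax(axis=1)
loss = episode_loss(query, query_labels, protos)
```

In a real model, the gradient of `loss` with respect to the embedding parameters would drive the SGD step; here the embeddings are fixed, so the block only shows the forward computation.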
Zero-Shot Learning
Zero-shot learning differs from few-shot learning in that instead of being given a support set of training points, we are given a class meta-data vector $v_k$ for each class. These could be determined in advance, or they could be learned from, e.g., raw text [8]. Modifying Prototypical Networks to deal with the zero-shot case is straightforward: we simply define $c_k = g_\vartheta(v_k)$ to be a separate embedding of the meta-data vector. An illustration of the zero-shot procedure for Prototypical Networks as it relates to the few-shot procedure is shown in Figure 1. Since the meta-data vector and query point come from different input domains, we found it was helpful empirically to fix the prototype embedding g to have unit length; however, we do not constrain the query embedding f.
… we performed experiments on Omniglot [18] and the miniImageNet version … with the splits proposed by Ravi and Larochelle [24]. We perform zero-shot experiments on the 2011 version of the Caltech UCSD bird dataset (CUB-200 2011) [34].
3.1 Omniglot Few-shot Classification
Omniglot is a dataset of 1623 handwritten characters collected from 50 alphabets. There are 20 …
21.
Table 1: Few-shot classification accuracies on Omniglot. * Uses non-standard train/test splits.
| Model | Dist. | Fine Tune | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot |
|---|---|---|---|---|---|---|
| MATCHING NETWORKS [32] | Cosine | N | 98.1% | 98.9% | 93.8% | 98.5% |
| MATCHING NETWORKS [32] | Cosine | Y | 97.9% | 98.7% | 93.5% | 98.7% |
| NEURAL STATISTICIAN [7] | - | N | 98.1% | 99.5% | 93.2% | 98.1% |
| MAML [9]* | - | N | 98.7% | 99.9% | 95.8% | 98.9% |
| PROTOTYPICAL NETWORKS (OURS) | Euclid. | N | 98.8% | 99.7% | 96.0% | 98.9% |
Table 2: Few-shot classification accuracies on miniImageNet. All accuracy results are averaged over 600 test episodes and are reported with 95% confidence intervals. * Results reported by [24].
| Model | Dist. | Fine Tune | 5-way 1-shot | 5-way 5-shot |
|---|---|---|---|---|
| BASELINE NEAREST NEIGHBORS* | Cosine | N | 28.86 ± 0.54% | 49.79 ± 0.79% |
| MATCHING NETWORKS [32]* | Cosine | N | 43.40 ± 0.78% | 51.09 ± 0.71% |
| MATCHING NETWORKS FCE [32]* | Cosine | N | 43.56 ± 0.84% | 55.31 ± 0.73% |
| META-LEARNER LSTM [24]* | - | N | 43.44 ± 0.77% | 60.60 ± 0.71% |
| MAML [9] | - | N | 48.70 ± 1.84% | 63.15 ± 0.91% |
| PROTOTYPICAL NETWORKS (OURS) | Euclid. | N | 49.42 ± 0.78% | 68.20 ± 0.66% |
3.2 miniImageNet Few-shot Classification
The miniImageNet dataset, originally proposed by Vinyals et al. [32], is derived from the larger
Figure 3: Comparison showing the effect of distance metric and number of classes per training
episode on 5-way classification accuracy for both Matching Networks and Prototypical Networks
on miniImageNet. The x-axis indicates configuration of the training episodes (way, distance, and
shot), and the y-axis indicates 5-way test accuracy for the corresponding shot. Error bars indicate
95% confidence intervals as computed over 600 test episodes. Note that Matching Networks and
Prototypical Networks are identical in the 1-shot case.
Table 3: Zero-shot classification accuracies on CUB-200.
| Model | Image Features | 50-way 0-shot Acc. |
|---|---|---|
| ALE [1] | Fisher | 26.9% |
| SJE [2] | AlexNet | 40.3% |
| SAMPLE CLUSTERING [19] | AlexNet | 44.3% |
| SJE [2] | GoogLeNet | 50.1% |
| DS-SJE [25] | GoogLeNet | 50.4% |
| DA-SJE [25] | GoogLeNet | 50.9% |
| SYNTHESIZED CLASSIFIERS [6] | GoogLeNet | 54.7% |
| PROTOTYPICAL NETWORKS (OURS) | GoogLeNet | 54.8% |
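22.
● $\theta_t = f_t \odot \theta_{t-1} - i_t \odot \nabla_{\theta_{t-1}} \mathcal{L}_t$
○ Gradient descent is the special case $\theta_t = 1 \cdot \theta_{t-1} - \eta \cdot \nabla_{\theta_{t-1}} \mathcal{L}_t$
○ Compare the LSTM cell update $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$
● $f_t$ and $i_t$ are computed from $\nabla_{\theta_{t-1}} \mathcal{L}_t$, $\mathcal{L}_t$, and $\theta_{t-1}$
● The initial parameters $\theta_0$ are also learned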
… cover classes not present in any of the datasets in $\mathcal{D}_{meta-train}$ (similarly, we additionally have a meta-validation set that is used to determine hyper-parameters).
…
$$\theta_t = \theta_{t-1} - \alpha_t \nabla_{\theta_{t-1}} \mathcal{L}_t$$
where $\theta_{t-1}$ are the parameters of the learner after t−1 updates, $\alpha_t$ is the learning rate at time t, $\mathcal{L}_t$ is the loss optimized by the learner for its t-th update, $\nabla_{\theta_{t-1}} \mathcal{L}_t$ is the gradient of that loss with respect to parameters $\theta_{t-1}$, and $\theta_t$ is the updated parameters of the learner.
Our key observation that we leverage here is that this update resembles the update for the cell state in an LSTM (Hochreiter & Schmidhuber, 1997)
$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t,$$
if $f_t = 1$, $c_{t-1} = \theta_{t-1}$, $i_t = \alpha_t$, and $\tilde{c}_t = -\nabla_{\theta_{t-1}} \mathcal{L}_t$.
Thus, we propose training a meta-learner LSTM to learn an update rule for training a neural network. We set the cell state of the LSTM to be the parameters of the learner, or $c_t = \theta_t$, and the candidate cell state $\tilde{c}_t = -\nabla_{\theta_{t-1}} \mathcal{L}_t$, given how valuable information about the gradient is for optimization. We define parametric forms for $i_t$ and $f_t$ so that the meta-learner can determine optimal values through the course of the updates.
Let us start with $i_t$, which corresponds to the learning rate for the updates. We let
$$i_t = \sigma\left(W_I \cdot \left[\nabla_{\theta_{t-1}} \mathcal{L}_t, \mathcal{L}_t, \theta_{t-1}, i_{t-1}\right] + b_I\right),$$
meaning that the learning rate is a function of the current parameter value $\theta_{t-1}$, the current gradient $\nabla_{\theta_{t-1}} \mathcal{L}_t$, the current loss $\mathcal{L}_t$, and the previous learning rate $i_{t-1}$. With this information, the meta-learner should be able to finely control the learning rate so as to train the learner quickly while avoiding divergence.
As for $f_t$, it seems possible that the optimal choice isn't the constant 1. Intuitively, what would justify shrinking the parameters of the learner and forgetting part of its previous value would be if the learner is currently in a bad local optimum and needs a large change to escape. This would correspond to a situation where the loss is high but the gradient is close to zero. Thus, one proposal for the forget gate is to have it be a function of that information, as well as the previous value of the forget gate:
$$f_t = \sigma\left(W_F \cdot \left[\nabla_{\theta_{t-1}} \mathcal{L}_t, \mathcal{L}_t, \theta_{t-1}, f_{t-1}\right] + b_F\right).$$
Additionally, notice that we can also learn the initial value of the cell state $c_0$ for the learner, treating it as a parameter of the meta-learner. This corresponds to the initial weights of the learner.
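A single gated update of this kind can be sketched in NumPy. Assumptions: the gates are the sigmoid-linear forms above applied coordinate-wise, with one small weight vector shared across all learner parameters, and the meta-training of $W_I, b_I, W_F, b_F$ (backpropagating through unrolled learner updates) is omitted. The function and argument names are hypothetical. The usage part picks gate parameters that give $f_t \approx 1$ and $i_t = 0.1$, recovering plain gradient descent as in the LSTM correspondence above.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def meta_lstm_step(theta, grad, loss, i_prev, f_prev, W_I, b_I, W_F, b_F):
    """One learner update c_t = f_t * c_{t-1} + i_t * c~_t with c_t = theta_t.

    Gates are computed coordinate-wise from [grad, loss, theta_{t-1}, previous
    gate]; W_I and W_F have shape (4,) and are shared by all coordinates.
    """
    loss_vec = np.full_like(theta, loss)
    i_t = sigmoid(np.stack([grad, loss_vec, theta, i_prev], axis=-1) @ W_I + b_I)
    f_t = sigmoid(np.stack([grad, loss_vec, theta, f_prev], axis=-1) @ W_F + b_F)
    c_tilde = -grad  # candidate cell state
    theta_new = f_t * theta + i_t * c_tilde
    return theta_new, i_t, f_t


# Usage: gate parameters chosen so that f_t ~ 1 and i_t = 0.1 (plain SGD).
theta = np.array([1.0, -2.0, 0.5])
grad = np.array([0.2, -0.4, 1.0])
theta_new, i_t, f_t = meta_lstm_step(
    theta, grad, loss=0.7, i_prev=np.zeros(3), f_prev=np.ones(3),
    W_I=np.zeros(4), b_I=np.log(0.1 / 0.9), W_F=np.zeros(4), b_F=50.0)
```

With nonzero weights, the same function lets the step size and the shrinkage of θ vary per coordinate and per step, which is the extra flexibility the meta-learner is trained to exploit.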
Oriol Vinyals, NIPS 17
Figure Credit: Hugo Larochelle
Optimization Based Meta Learning
23.
Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
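Algorithm 1 Model-Agnostic Meta-Learning
Require: $p(\mathcal{T})$: distribution over tasks
Require: $\alpha$, $\beta$: step size hyperparameters
1: randomly initialize $\theta$
2: while not done do
3:   Sample batch of tasks $\mathcal{T}_i \sim p(\mathcal{T})$
4:   for all $\mathcal{T}_i$ do
5:     Evaluate $\nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)$ with respect to K examples
6:     Compute adapted parameters with gradient descent: $\theta'_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)$
7:   end for
8:   Update $\theta \leftarrow \theta - \beta \nabla_\theta \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})$
9: end while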
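Line 8 of Algorithm 1 differentiates through the inner update of line 6, which frameworks normally handle with automatic differentiation. To keep the sketch dependency-light, the block below assumes a linear-regression learner with squared error, where the inner update and its Jacobian $I - \alpha H$ can be written by hand, so the second-order meta-gradient is exact. Evaluating the outer loss on held-out examples of each task is an assumption here (Algorithm 1 as shown does not say which examples are used), and all names are hypothetical.

```python
import numpy as np


def mse(theta, X, y):
    """Task loss L(f_theta): mean squared error of a linear model."""
    return np.mean((X @ theta - y) ** 2)


def mse_grad(theta, X, y):
    """Gradient of mse with respect to theta."""
    return 2.0 * X.T @ (X @ theta - y) / len(y)


def maml_meta_gradient(theta, tasks, alpha):
    """Exact gradient of sum_i L_Ti(f_theta'_i) for one inner step.

    tasks: list of (X_tr, y_tr, X_val, y_val); the inner step uses the K
    training examples, the outer loss uses held-out examples of the task.
    """
    meta_grad = np.zeros_like(theta)
    for X_tr, y_tr, X_val, y_val in tasks:
        # Line 6: theta'_i = theta - alpha * grad L_Ti(f_theta)
        theta_i = theta - alpha * mse_grad(theta, X_tr, y_tr)
        # d theta'_i / d theta = I - alpha * Hessian of the inner loss
        hessian = 2.0 * X_tr.T @ X_tr / len(y_tr)
        jac = np.eye(len(theta)) - alpha * hessian
        # Chain rule, keeping the second-order term
        meta_grad += jac.T @ mse_grad(theta_i, X_val, y_val)
    return meta_grad


def maml_step(theta, tasks, alpha, beta):
    """Line 8: theta <- theta - beta * grad_theta sum_i L_Ti(f_theta'_i)."""
    return theta - beta * maml_meta_gradient(theta, tasks, alpha)


# Usage: a batch of random linear-regression tasks (line 3).
rng = np.random.default_rng(0)


def make_task(k=5, n_val=10, dim=3):
    w = rng.normal(size=dim)
    X_tr, X_val = rng.normal(size=(k, dim)), rng.normal(size=(n_val, dim))
    return X_tr, X_tr @ w, X_val, X_val @ w


tasks = [make_task() for _ in range(4)]
theta = rng.normal(size=3)
theta_next = maml_step(theta, tasks, alpha=0.05, beta=0.01)
```

Dropping `jac` (using the identity instead) gives the cheaper first-order approximation; keeping it is what makes the update optimize for post-adaptation performance exactly.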
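[Figure: meta-learning finds a representation θ from which learning/adaptation with gradients ∇L1, ∇L2, ∇L3 quickly reaches the task-specific optima θ*1, θ*2, θ*3]
Figure 1. Diagram of our model-agnostic meta-learning algorithm (MAML), which optimizes for a representation θ that can quickly adapt to new tasks.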
Oriol Vinyals, NIPS 17
Summing Up
Model Based
Metric Based
Optimization Based
Examples of Optimization Based Meta Learning
● Finn et al. ’17
● Ravi & Larochelle ’17
24.
[Figure: network architecture. A demonstration network (temporal dropout, temporal convolution, neighborhood attention) embeds the demonstration; a context network applies attention over the demonstration and attention over the blocks (A-J) in the current state to produce a context embedding; a manipulation network maps the context embedding to an action]
25.
One-Shot Imitation Learning
[Figure: block embeddings h1, h2, h3, …, hN; N attentions, where the i-th attention is applied to the i-th block versus the others]
Figure 4. Illustration of the neighborhood attention operation. It receives a list of embeddings for each block, performs one attention …
After downsampling the demonstration, we apply a sequence of operations, composed of dilated temporal convolution (Yu & Koltun, 2016) and neighborhood attention.
4.2.3. CONTEXT NETWORK
The context network is the crux of our model. Illustrated in Fig. 6, it processes both the current state and the embedding produced by the demonstration network, and outputs …
26.
27.
Figure 2: Illustration of the network architecture.
In our experiments, we use p = 0.95, which reduces the length of demonstrations by a factor of 20.
During test time, we can sample multiple downsampled trajectories, use each of them to compute
downstream results, and average these results to produce an ensemble estimate. In our experience,
this consistently improves the performance of the policy.
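Temporal dropout and the test-time ensemble described above can be sketched as follows. The sketch assumes each time step is dropped independently with probability p, so p = 0.95 keeps roughly one step in 20, and that `policy` is some callable returning an action distribution; both function names are illustrative, not taken from the paper's code.

```python
import numpy as np


def temporal_dropout(demo, p, rng):
    """Drop each time step of demo independently with probability p.

    demo: array whose first axis is time. Kept steps stay in order; at least
    one step is always kept (an assumption for very short demonstrations).
    """
    keep = rng.random(len(demo)) >= p
    if not keep.any():
        keep[rng.integers(len(demo))] = True
    return demo[keep]


def ensemble_action_distribution(policy, demo, state, p, n_samples, rng):
    """Average the action distributions obtained from several downsampled demos."""
    dists = [policy(temporal_dropout(demo, p, rng), state) for _ in range(n_samples)]
    return np.mean(dists, axis=0)


# Usage with a toy 2000-step demonstration and a dummy two-action policy.
rng = np.random.default_rng(0)
demo = np.arange(2000)
short = temporal_dropout(demo, p=0.95, rng=rng)  # keeps about 5% of the steps on average


def toy_policy(demo, state):
    frac = len(demo) / 2000
    return np.array([frac, 1.0 - frac])


avg = ensemble_action_distribution(toy_policy, demo, state=None, p=0.95,
                                   n_samples=10, rng=rng)
```

Averaging probability vectors keeps the result a valid distribution, so the greedy action can be read off the ensemble estimate directly.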
Neighborhood Attention: After downsampling the demonstration, we apply a sequence of opera-
tions, composed of dilated temporal convolution [65] and neighborhood attention. We now describe
this second operation in more detail.
From the figure, we can observe that for the easier tasks with fewer stages, all of the different
conditioning strategies perform equally well and almost perfectly. As the difficulty (number of stages)
increases, however, conditioning on the entire demonstration starts to outperform conditioning on the
final state. One possible explanation is that when conditioned only on the final state, the policy may struggle to determine which block it should stack first, a piece of information that is readily accessible from the demonstration, which not only communicates the task but also provides valuable information to help accomplish it.
More surprisingly, conditioning on the entire demonstration also seems to outperform conditioning
on the snapshot, which we originally expected to perform the best. We suspect that this is due
to the regularization effect introduced by temporal dropout, which effectively augments the set of
demonstrations seen by the policy during training.
Another interesting finding was that training with behavioral cloning has the same level of performance
as training with DAGGER, which suggests that the entire training procedure could work without
requiring interactive supervision. In our preliminary experiments, we found that injecting noise into
the trajectory collection process was important for behavioral cloning to work well, hence in all
experiments reported here we use noise injection. In practice, such noise can come from natural
human-induced noise through tele-operation, or by artificially injecting additional noise before
applying it on the physical robot.
5.2 Visualization
We visualize the attention mechanisms underlying the main policy architecture to better understand how it operates. There are two kinds of attention we are mainly interested in,
one where the policy attends to different time steps in the demonstration, and the other where the
policy attends to different blocks in the current state. Fig. 4 shows some of the attention heatmaps.
(a) Attention over blocks in the current state. (b) Attention over downsampled demonstration.
Figure 4: Visualizing attentions performed by the policy during an entire execution. The task …
28.
29.
… per task for training, and maintain a separate set of trajectories and initial configurations to be used
for evaluation. The trajectories are collected using a hard-coded policy.
5.1 Performance Evaluation
[Figure 3: average success rate (0% to 100%) versus number of stages for each policy type (Demo, BC, DAGGER, Snapshot, Final state). (a) Performance on training tasks (1 to 7 stages). (b) Performance on test tasks (2, 4, 5, 6, 7, 8 stages).]
Figure 3: Comparison of different conditioning strategies. The darkest bar shows the performance of the
hard-coded policy, which unsurprisingly performs the best most of the time. For architectures that use temporal
dropout, we use an ensemble of 10 different downsampled demonstrations and average the action distributions.
Then for all architectures we use the greedy action for evaluation.
Fig. 3 shows the performance of various architectures. Results for training and test tasks are presented
separately, where we group tasks by the number of stages required to complete them. This is because
tasks that require more stages to complete are typically more challenging. In fact, even our scripted
policy frequently fails on the hardest tasks. We measure success rate per task by executing the greedy
policy (taking the most confident action at every time step) in 100 different configurations, each
conditioned on a different demonstration unseen during training. We report the average success rate
over all tasks within the same group.
30.
How can we enable robots to learn vision-based manipulation skills that generalize to new objects & goals?
● Learn from raw pixel observations (rather than task-specific, engineered representations)
● Collect data with a diverse range of objects and environments
● Reuse data from other objects & tasks when learning to perform a new task

One-Shot Visual Imitation Learning (Chelsea Finn*, Tianhe Yu*, Tianhao Zhang, Pieter Abbeel, Sergey Levine)
● Can robots reuse data from other tasks to adapt to new objects from only one visual demonstration?
● Our meta-learning approach: learn to learn many other tasks using one demo
○ Meta-training: meta-training tasks {task 1, task 2, …}, each with training and test sets
○ Meta-testing: learn a new held-out task from 1 demo
● Meta-Imitation Learning using MAML [1,2]
○ Meta-training time: a training demo and a validation demo for each meta-training task
○ Meta-test time: a demo of the meta-test task with held-out objects
● Demo: robot placing; tasks correspond to different objects

Planning with Visual Foresight [4,5] (Frederik Ebert, Chelsea Finn, Alex Lee, Sergey Levine)
● How can robots acquire general models and skills using entirely autonomously collected data?
● Standard robotics paradigm: a brittle, hand-engineered pipeline (RGB-D image → segment objects → estimate pose & physics of segments → optimize action using estimated poses & physics)
● Our approach:
○ Collect data autonomously: program initial motions and provide objects; record camera images and robot actions; no object supervision, camera calibration, human annotation, etc.
○ Predict future video for different actions [3,5] (input image and actions → future predictions)
○ Planning module driven by user input / task specification

[1] Finn, Abbeel, Levine. Model-Agnostic Meta-…
[2] Finn*, Yu*, Zhang, Abbeel, Levine. One-Shot …
31.
Learning a One-Shot Imitator with MAML
● Meta-learning loss: $\min_\theta \sum_{\mathcal{T}_i} \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})$, where $\theta'_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)$
● Task loss = behavioral cloning loss [Pomerleau ’89, Sammut ’92]
[Finn*, Yu*, Zhang, Abbeel, Levine, 2017] Pieter Abbeel -- embody.ai / UC Berkeley / Gradescope
32.
Robot Experiments: Learning to Place
● Meta-training targets / objects
● Meta-testing targets / objects
● 1,300 demonstrations for meta-training
Pieter Abbeel -- embody.ai / UC Berkeley / Gradescope [Finn*, Yu*, Zhang, Abbeel, Levine, 2017]
33.
Robot Experiments: Learning to Place
1-demo imitation
Success rate: 90%
Pieter Abbeel -- embody.ai / UC Berkeley / Gradescope [Finn*, Yu*, Zhang, Abbeel, Levine, 2017]
35.
37.
Trends: Graph Networks
Message Passing Neural Networks
INPUT:
● Undirected graph G
● Node features h_v (vector)
● Edge features e_vw (vector)
MPNN
OUTPUT:
● Graph-level target

● A family of neural networks that commute with permutations of the vertex order.
● Generalizes several existing NN-based methods.
● Related to the Weisfeiler-Lehman algorithm for graph isomorphism (GI) testing.
38.
Neural Message Passing for Quantum Chemistry
… Oriol Vinyals, George E. Dahl
[Figure: predicting molecular properties (targets E, ω0, …): DFT takes about $10^3$ seconds, while a message passing neural net takes about $10^{-2}$ seconds]
39.
[Figure: for each neighbor w of v, a message function $M_t(h_v^t, h_w^t, e_{vw})$ is computed; the messages are summed (Σ) and passed to the update function $U_t(h_v^t, m_v^{t+1})$]
● Message function: $M_t(h_v^t, h_w^t, e_{vw}) = \mathrm{concat}(h_w^t, e_{vw})$
● Update function: $U_t(h_v^t, m_v^{t+1}) = \sigma(H_t^{\deg(v)} m_v^{t+1})$
● $H_t^{\deg(v)}$: a learned matrix for each time step t and vertex degree deg(v)
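One message passing step with the concat message and the degree-specific update can be sketched in NumPy. Assumptions: the graph is undirected with each edge listed once, the matrices $H_t^{\deg(v)}$ are passed in as a dict keyed by degree, and σ is the logistic sigmoid. Function and argument names are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def mp_step(h, edges, e, H_by_degree):
    """One step of eqs. (1)-(2) with
    M_t(h_v, h_w, e_vw) = concat(h_w, e_vw) and U_t(h_v, m_v) = sigmoid(H^deg(v) m_v).

    h: (n, d) node states; edges: undirected pairs (v, w), each listed once;
    e: dict (v, w) -> edge feature vector; H_by_degree: dict degree -> matrix.
    """
    n, d = h.shape
    d_e = len(next(iter(e.values())))
    m = np.zeros((n, d + d_e))
    deg = np.zeros(n, dtype=int)
    for v, w in edges:
        # Each endpoint receives the other's state concatenated with e_vw.
        m[v] += np.concatenate([h[w], e[(v, w)]])
        m[w] += np.concatenate([h[v], e[(v, w)]])
        deg[v] += 1
        deg[w] += 1
    # The update matrix is selected by the degree of each vertex.
    return np.stack([sigmoid(H_by_degree[deg[v]] @ m[v]) for v in range(n)])


# Usage: node 0 (v) with neighbors 1 (u1) and 2 (u2), as in the figure.
rng = np.random.default_rng(1)
d, d_e = 3, 2
h = rng.normal(size=(3, d))
edges = [(0, 1), (0, 2)]
e = {edge: rng.normal(size=d_e) for edge in edges}
H_by_degree = {k: rng.normal(size=(d, d + d_e)) for k in (1, 2)}
h_next = mp_step(h, edges, e, H_by_degree)
```

Because messages are summed, relabeling the nodes only permutes the rows of the output, which is the property the MPNN framework relies on.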
[Figure: node v with neighbors u1 and u2 and their initial hidden states $h_v^{(0)}$, $h_{u_1}^{(0)}$, $h_{u_2}^{(0)}$]
Neural Message Passing for Quantum Chemistry
… time steps and is defined in terms of message functions $M_t$ and vertex update functions $U_t$. During the message passing phase, hidden states $h_v^t$ at each node in the graph are updated based on messages $m_v^{t+1}$ according to
$$m_v^{t+1} = \sum_{w \in N(v)} M_t(h_v^t, h_w^t, e_{vw}) \tag{1}$$
$$h_v^{t+1} = U_t(h_v^t, m_v^{t+1}) \tag{2}$$
where in the sum, N(v) denotes the neighbors of v in graph G. The readout phase computes a feature vector for the whole graph using some readout function R according to
$$\hat{y} = R(\{h_v^T \mid v \in G\}). \tag{3}$$
The message functions $M_t$, vertex update functions $U_t$, and readout function R are all learned differentiable functions. R operates on the set of node states and must be invariant to permutations of the node states in order for the MPNN to be invariant to graph isomorphism. In what follows, we define previous models in the literature by specifying the message function $M_t$, vertex update function $U_t$, and readout function R used. Note one could also learn edge features in an MPNN by introducing hidden states for all edges in the graph $h_{e_{vw}}^t$ and updating them analogously to equations 1 and 2. Of the existing MPNNs, only Kearnes et al. (2016) …
… Recurrent Unit introduced in Cho et al. (2014). This work used weight tying, so the same update function is used at each time step t. Finally,
$$R = \sum_{v \in V} \sigma\left(i(h_v^{(T)}, h_v^0)\right) \odot \left(j(h_v^{(T)})\right) \tag{4}$$
where i and j are neural networks, and ⊙ denotes element-wise multiplication.
Interaction Networks, Battaglia et al. (2016)
This work considered both the case where there is a target at each node in the graph, and where there is a graph level target. It also considered the case where there are node level effects applied at each time step; in such a case the update function takes as input the concatenation $(h_v, x_v, m_v)$ where $x_v$ is an external vector representing some outside influence on the vertex v. The message function $M(h_v, h_w, e_{vw})$ is a neural network which takes the concatenation $(h_v, h_w, e_{vw})$. The vertex update function $U(h_v, x_v, m_v)$ is a neural network which takes as input the concatenation $(h_v, x_v, m_v)$. Finally, in the case where there is a graph level output, $R = f(\sum_{v \in G} h_v^T)$ where f is a neural network which takes the sum of the final hidden states $h_v^T$. Note the original work only defined the model for T = 1.
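Equations (1) to (4) of the excerpt can be sketched as a generic NumPy loop in which the message, update, and readout functions are passed in as callables. Assumptions: messages have the same size as node states, the same message and update functions are reused at every step, and the toy networks in the usage part are random layers standing in for learned ones; all names are hypothetical. Relabeling the nodes leaves the output unchanged, which is the permutation invariance the excerpt requires of R.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def mpnn_forward(h0, edges, message, update, readout, T):
    """T rounds of message passing (eqs. 1-2) followed by a readout (eq. 3).

    edges: (v, w, e_vw) triples meaning w is in N(v), i.e. w sends to v.
    Messages are assumed to have the same size as the node states.
    """
    h = h0
    for _ in range(T):
        m = np.zeros_like(h)
        for v, w, e_vw in edges:
            m[v] += message(h[v], h[w], e_vw)                      # eq. (1)
        h = np.stack([update(h[v], m[v]) for v in range(len(h))])  # eq. (2)
    return readout(h, h0)                                          # eq. (3)


def gated_readout(i_net, j_net):
    """Eq. (4): R = sum_v sigmoid(i(h_v^T, h_v^0)) * j(h_v^T), elementwise."""
    def R(hT, h0):
        return sum(sigmoid(i_net(np.concatenate([hT[v], h0[v]]))) * j_net(hT[v])
                   for v in range(len(hT)))
    return R


# Usage with random layers standing in for the learned networks.
rng = np.random.default_rng(2)
d, k = 4, 3
W_m, W_u = rng.normal(size=(d, 2 * d + 1)), rng.normal(size=(d, 2 * d))
W_i, W_j = rng.normal(size=(k, 2 * d)), rng.normal(size=(k, d))


def message(h_v, h_w, e_vw):
    return np.tanh(W_m @ np.concatenate([h_v, h_w, e_vw]))


def update(h_v, m_v):
    return np.tanh(W_u @ np.concatenate([h_v, m_v]))


R = gated_readout(lambda x: W_i @ x, lambda x: W_j @ x)
h0 = rng.normal(size=(3, d))
# Path graph 0-1-2, each undirected edge listed in both directions.
edges = [(0, 1, np.array([1.0])), (1, 0, np.array([1.0])),
         (1, 2, np.array([2.0])), (2, 1, np.array([2.0]))]
y_hat = mpnn_forward(h0, edges, message, update, R, T=3)
```

Swapping `gated_readout` for a plain sum followed by a network gives the Interaction Networks readout $R = f(\sum_v h_v^T)$ described above.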
40.
● Readout function: $R(\{h_v^T \mid v \in G\}) = f\left(\sum_{v,t} \mathrm{softmax}(W_t h_v^t)\right)$
[Figure: node v with neighbors u1 and u2, connected by edges $e_{vu_1}$ and $e_{vu_2}$; the initial states $h^{(0)}$ are updated to $h^{(T)}$, and the readout computes $\hat{y} = R(\{h_v^{(T)} \mid v \in G\})$]
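A NumPy sketch of this softmax readout, pooling node states from every step t. The readout matrices $W_t$ and the network f are placeholders (f is any callable here), and all names are hypothetical. Because each softmax sums to 1, the pooled vector always sums to the number of (node, step) pairs, which makes a convenient sanity check.

```python
import numpy as np


def softmax(z):
    z = z - z.max()  # numerical stability
    e = np.exp(z)
    return e / e.sum()


def softmax_readout(h_history, W, f):
    """R = f(sum over v and t of softmax(W_t h_v^t)).

    h_history: list of (n, d) node-state arrays, one per step t;
    W: list of (k, d) readout matrices W_t, one per step;
    f: any callable standing in for the neural network f.
    """
    pooled = sum(softmax(W_t @ h_v)
                 for W_t, h_t in zip(W, h_history)
                 for h_v in h_t)
    return f(pooled)


# Usage: 4 nodes, states stored for steps t = 0, 1, 2.
rng = np.random.default_rng(3)
n, d, k = 4, 5, 6
h_history = [rng.normal(size=(n, d)) for _ in range(3)]
W = [rng.normal(size=(k, d)) for _ in range(3)]
pooled = softmax_readout(h_history, W, f=lambda x: x)
y_hat = softmax_readout(h_history, W, f=np.tanh)
```

Summing over nodes makes the readout independent of node order, as required of R.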
41.
● Message function: $M_t(h_v^t, h_w^t, e_{vw}) = A_{e_{vw}} h_w^t$
● $A_{e_{vw}}$: a learned matrix for each edge label
● Update function: $U_t(h_v^t, m_v^{t+1}) = \mathrm{GRU}(h_v^t, m_v^{t+1})$
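The edge-label matrices and the GRU update can be sketched in NumPy. Assumptions: one common GRU convention without bias terms, messages sent along directed (v, w, label) triples, and the same parameters reused at every step; the bond-type labels "single" and "double" are purely illustrative. With all GRU matrices set to zero, both gates equal 0.5 and the candidate state is 0, so the update halves every node state, a quick sanity check.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gru_update(h, m, P):
    """GRU cell (bias terms omitted) with state h and input message m."""
    z = sigmoid(P["Wz"] @ m + P["Uz"] @ h)             # update gate
    r = sigmoid(P["Wr"] @ m + P["Ur"] @ h)             # reset gate
    h_cand = np.tanh(P["Wh"] @ m + P["Uh"] @ (r * h))  # candidate state
    return (1.0 - z) * h + z * h_cand


def ggnn_step(h, edges, A, P):
    """M_t = A_{e_vw} h_w (one matrix per edge label) and U_t = GRU(h_v, m_v).

    edges: (v, w, label) triples meaning w sends a message to v.
    The same A and P are reused at every step (weight tying).
    """
    m = np.zeros_like(h)
    for v, w, label in edges:
        m[v] += A[label] @ h[w]
    return np.stack([gru_update(h[v], m[v], P) for v in range(len(h))])


# Usage: a 3-node path with hypothetical bond-type edge labels.
rng = np.random.default_rng(4)
d = 3
P = {name: rng.normal(size=(d, d)) for name in ("Wz", "Uz", "Wr", "Ur", "Wh", "Uh")}
A = {"single": rng.normal(size=(d, d)), "double": rng.normal(size=(d, d))}
h = rng.normal(size=(3, d))
edges = [(0, 1, "single"), (1, 0, "single"), (1, 2, "double"), (2, 1, "double")]
h_next = ggnn_step(h, edges, A, P)
```

Running `ggnn_step` T times with the same `A` and `P` mirrors the weight tying mentioned in the excerpt; the gated readout of eq. (4) would then be applied to the final states.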
Neural Message Passing for Quantum Chemistry
time steps and is defined in terms of message functions $M_t$ and vertex update functions $U_t$. During the message passing phase, hidden states $h_v^t$ at each node in the graph are updated based on messages $m_v^{t+1}$ according to
$m_v^{t+1} = \sum_{w \in N(v)} M_t(h_v^t, h_w^t, e_{vw})$ (1)
$h_v^{t+1} = U_t(h_v^t, m_v^{t+1})$ (2)
where in the sum, $N(v)$ denotes the neighbors of $v$ in graph $G$. The readout phase computes a feature vector for the whole graph using some readout function $R$ according to
$\hat{y} = R(\{h_v^T \mid v \in G\})$. (3)
The message functions $M_t$, vertex update functions $U_t$, and readout function $R$ are all learned differentiable functions. $R$ operates on the set of node states and must be invariant to permutations of the node states in order for the MPNN to be invariant to graph isomorphism. In what follows, we define previous models in the literature by specifying the message function $M_t$, vertex update function $U_t$, and readout function $R$ used. Note one could also learn edge features in an MPNN by introducing hidden states for all edges in the graph $h_{e_{vw}}^t$ and updating them analogously to equations 1 and 2. Of the existing MPNNs, only Kearnes et al. (2016)
Recurrent Unit introduced in Cho et al. (2014). This work used weight tying, so the same update function is used at each time step $t$. Finally,
$R = \sum_{v \in V} \sigma\left(i(h_v^{(T)}, h_v^0)\right) \odot \left(j(h_v^{(T)})\right)$ (4)
where $i$ and $j$ are neural networks, and $\odot$ denotes element-wise multiplication.
Interaction Networks, Battaglia et al. (2016)
This work considered both the case where there is a target at each node in the graph, and where there is a graph level target. It also considered the case where there are node level effects applied at each time step, in such a case the update function takes as input the concatenation $(h_v, x_v, m_v)$ where $x_v$ is an external vector representing some outside influence on the vertex $v$. The message function $M(h_v, h_w, e_{vw})$ is a neural network which takes the concatenation $(h_v, h_w, e_{vw})$. The vertex update function $U(h_v, x_v, m_v)$ is a neural network which takes as input the concatenation $(h_v, x_v, m_v)$. Finally, in the case where there is a graph level output, $R = f\left(\sum_{v \in G} h_v^T\right)$ where $f$ is a neural network which takes the sum of the final hidden states $h_v^T$. Note the original work only defined the model for $T = 1$.
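Equations (1)-(3) describe a generic loop that every model on these slides fills in with its own $M_t$, $U_t$ and $R$. The sketch below implements that loop in pure Python with the three functions passed in as arguments. The toy message, update and readout are hypothetical and do not correspond to any published model: the message is the neighbor state scaled by a scalar edge weight, the update is a sum, and the readout is an element-wise sum.

```python
def mpnn_forward(h0, edges, message, update, readout, T):
    """Eqs. (1)-(3): T rounds of message passing followed by a readout.

    h0: dict node -> initial state (list of floats).
    edges: dict (v, w) -> edge feature e_vw; store both directions for an
    undirected graph so that w is in N(v) and v is in N(w).
    """
    h = dict(h0)
    neighbors = {v: [] for v in h}
    for v, w in edges:
        neighbors[v].append(w)
    for t in range(T):
        m = {}
        for v in h:
            msgs = [message(t, h[v], h[w], edges[(v, w)]) for w in neighbors[v]]
            # Eq. (1): element-wise sum of messages; isolated nodes get zeros
            # (assumes messages have the same size as the states).
            m[v] = [sum(c) for c in zip(*msgs)] if msgs else [0.0] * len(h[v])
        # Eq. (2): all nodes are updated from the previous step's states.
        h = {v: update(t, h[v], m[v]) for v in h}
    return readout(list(h.values()))  # Eq. (3)

def message(t, hv, hw, e):
    """Toy M_t: neighbor state scaled by a scalar edge weight."""
    return [e * x for x in hw]

def update(t, hv, mv):
    """Toy U_t: add the aggregated message to the state."""
    return [a + b for a, b in zip(hv, mv)]

def readout(states):
    """Toy R: element-wise sum, which is invariant to node order."""
    return [sum(c) for c in zip(*states)]

h0 = {0: [1.0], 1: [2.0], 2: [3.0]}
edges = {(0, 1): 1.0, (1, 0): 1.0, (1, 2): 1.0, (2, 1): 1.0}
print(mpnn_forward(h0, edges, message, update, readout, T=2))  # [34.0]
```

Because the readout is a sum, relabeling the nodes of the same graph gives the same output, which is the isomorphism invariance the paragraph above requires of $R$.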
Update Function: $U_t(h_v^t, m_v^{t+1})$
42. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
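• Readout function:
$R(\{h_v^T \mid v \in G\}) = \tanh\left(\sum_v \sigma\left(i(h_v^T, h_v^0)\right) \odot \tanh\left(j(h_v^T, h_v^0)\right)\right)$
• $i$, $j$: neural networks; since $\sigma(i(h_v^T, h_v^0))$ lies in (0, 1), it acts as a gate on how much each node contributes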
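The gated readout above can be sketched in pure Python. For brevity, this sketch takes $i$ and $j$ to be single linear maps on the concatenation $[h_v^T; h_v^0]$ (an assumption; in the model they are neural networks). The function name is hypothetical.

```python
import math

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gated_readout(hT, h0, Wi, Wj):
    """R = tanh( sum_v sigmoid(i(h_v^T, h_v^0)) * tanh(j(h_v^T, h_v^0)) ).

    hT, h0: dict node -> final / initial state. Wi, Wj: linear stand-ins
    for the networks i and j, applied to [h_v^T; h_v^0].
    """
    out = [0.0] * len(Wi)
    for v in hT:
        x = hT[v] + h0[v]  # list concatenation = [h_v^T; h_v^0]
        gate = [sigmoid(a) for a in matvec(Wi, x)]
        value = [math.tanh(a) for a in matvec(Wj, x)]
        out = [o + g * val for o, g, val in zip(out, gate, value)]
    return [math.tanh(o) for o in out]

hT = {0: [1.0], 1: [-1.0]}
h0 = {0: [0.0], 1: [0.0]}
# Equal gates: the two opposite contributions cancel.
print(gated_readout(hT, h0, Wi=[[0.0, 0.0]], Wj=[[1.0, 0.0]]))
# A strongly negative gate weight closes the gate for node 0 only.
print(gated_readout(hT, h0, Wi=[[-50.0, 0.0]], Wj=[[1.0, 0.0]]))
```

The second call shows how the gate switches off a node. Only node 1 contributes, so the result is about $\tanh(\tanh(-1)) \approx -0.64$.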
43. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
• Message function:
$M_t(h_v^t, h_w^t, e_{vw}) = \tanh\left(W^{fc}\left((W^{cf} h_w^t + b_1) \odot (W^{df} e_{vw} + b_2)\right)\right)$
• $W^{fc}$, $W^{cf}$, $W^{df}$: weight matrices; $b_1$, $b_2$: bias vectors
• Update function: $U_t(h_v^t, m_v^{t+1}) = h_v^t + m_v^{t+1}$
[Figure: the messages $M_t(h_v^t, h_{w_1}^t, e_{vw_1})$ and $M_t(h_v^t, h_{w_2}^t, e_{vw_2})$ are summed (Σ) into $m_v^{t+1}$.]
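The message and additive update above can be sketched in pure Python. The parameter names are hypothetical, and the sketch assumes $W^{fc}$ outputs vectors of the same size as the node states so that $h_v^t + m_v^{t+1}$ is defined. Note that in this message the receiving state $h_v^t$ does not appear; only $h_w^t$ and $e_{vw}$ do.

```python
import math

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def vadd(a, b):
    """Element-wise sum of two vectors."""
    return [x + y for x, y in zip(a, b)]

def dtnn_message(h_w, e_vw, P):
    """tanh(W_fc ((W_cf h_w + b1) * (W_df e_vw + b2))), '*' element-wise."""
    left = vadd(matvec(P["W_cf"], h_w), P["b1"])
    right = vadd(matvec(P["W_df"], e_vw), P["b2"])
    prod = [a * b for a, b in zip(left, right)]
    return [math.tanh(a) for a in matvec(P["W_fc"], prod)]

def dtnn_step(h, edges, P):
    """m_v = sum over w of M(h_w, e_vw); then h_v <- h_v + m_v."""
    m = {v: [0.0] * len(h[v]) for v in h}
    for (v, w), e in edges.items():
        m[v] = vadd(m[v], dtnn_message(h[w], e, P))
    return {v: vadd(h[v], m[v]) for v in h}

P = {"W_cf": [[1.0]], "b1": [0.0], "W_df": [[1.0]], "b2": [0.0], "W_fc": [[1.0]]}
h = {0: [0.5], 1: [1.0]}
edges = {(0, 1): [1.0], (1, 0): [1.0]}
print(dtnn_step(h, edges, P))  # node 0: 0.5 + tanh(1.0), node 1: 1.0 + tanh(0.5)
```

If the edge term $W^{df} e_{vw} + b_2$ is zero, the element-wise product cancels the message, so the edge features can switch interactions on and off.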
Update Function: $U_t(h_v^t, m_v^{t+1})$
44. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
• Readout function: $R(\{h_v^T \mid v \in G\}) = \sum_v \mathrm{NN}(h_v^T)$
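The readout above applies the same network to every final node state and then sums the results. Below is a minimal sketch in which a two-layer tanh perceptron stands in for $\mathrm{NN}$. This is an assumption, since the slide does not give the architecture, and the names are hypothetical.

```python
import math

def mlp(x, W1, b1, W2, b2):
    """Two-layer perceptron with tanh hidden units (stand-in for NN)."""
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(W1, b1)]
    return [sum(w * hi for w, hi in zip(row, hidden)) + b
            for row, b in zip(W2, b2)]

def sum_readout(final_states, params):
    """R = sum over v of NN(h_v^T); the sum ignores node order."""
    outs = [mlp(h, *params) for h in final_states]
    return [sum(c) for c in zip(*outs)]

states = [[0.1, 0.2], [0.3, -0.1], [1.0, 0.5]]
params = ([[0.5, -0.2], [0.1, 0.4]], [0.0, 0.1], [[1.0, -1.0]], [0.2])
total = sum_readout(states, params)
print(total)
```

Because each node contributes independently and the contributions are summed, the output is unchanged by reordering the nodes.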
45. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
• Message function: $M_t(h_v^t, h_w^t, e_{vw}) = A(e_{vw}) h_w^t$
• $A(e_{vw})$: a neural network that maps the edge vector $e_{vw}$ to a $d \times d$ matrix
• Update function: $U_t(h_v^t, m_v^{t+1}) = \mathrm{GRU}(h_v^t, m_v^{t+1})$
[Figure: the messages $M_t(h_v^t, h_{w_1}^t, e_{vw_1})$ and $M_t(h_v^t, h_{w_2}^t, e_{vw_2})$ are summed (Σ) into $m_v^{t+1}$.]
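Here the matrix applied to $h_w^t$ is computed from the edge vector, unlike the per-label matrices $A_{e_{vw}}$ seen earlier. As a result, $e_{vw}$ can be continuous. Below is a minimal sketch in which a single linear layer stands in for the network $A$ (an assumption). It produces $d \cdot d$ numbers that are reshaped into a $d \times d$ matrix. The function names are hypothetical.

```python
def edge_network(e, W, b, d):
    """A(e): linear layer from edge features to d*d numbers, reshaped to d x d."""
    flat = [sum(w * ei for w, ei in zip(row, e)) + bi for row, bi in zip(W, b)]
    return [flat[r * d:(r + 1) * d] for r in range(d)]

def edge_network_message(h_w, e_vw, W, b):
    """M(h_v, h_w, e_vw) = A(e_vw) @ h_w."""
    d = len(h_w)
    A = edge_network(e_vw, W, b, d)
    return [sum(a * x for a, x in zip(row, h_w)) for row in A]

# With this W and zero bias, A(e) = e * I, so the message is e scaled h_w.
W = [[1.0], [0.0], [0.0], [1.0]]
b = [0.0, 0.0, 0.0, 0.0]
print(edge_network_message([1.0, -3.0], [2.0], W, b))  # [2.0, -6.0]
```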
Update Function: $U_t(h_v^t, m_v^{t+1})$
46. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
• Readout function: $R(\{h_v^T \mid v \in G\}) = \mathrm{set2set}(\{h_v^T \mid v \in G\})$
size of the set, and which is order invariant. In the next sections, we explain such a modification,
which could also be seen as a special case of a Memory Network (Weston et al., 2015) or Neural
Turing Machine (Graves et al., 2014) – with a computation flow as depicted in Figure 1.
4.2 ATTENTION MECHANISMS
Neural models with memories coupled to differentiable addressing mechanism have been successfully applied to handwriting generation and recognition (Graves, 2012), machine translation (Bahdanau et al., 2015a), and more general computation machines (Graves et al., 2014; Weston et al., 2015). Since we are interested in associative memories we employed a “content” based attention. This has the property that the vector retrieved from our memory would not change if we randomly shuffled the memory. This is crucial for proper treatment of the input set X as such. In particular, our process block based on an attention mechanism uses the following:
$q_t = \mathrm{LSTM}(q^*_{t-1})$ (3)
$e_{i,t} = f(m_i, q_t)$ (4)
$a_{i,t} = \frac{\exp(e_{i,t})}{\sum_j \exp(e_{j,t})}$ (5)
$r_t = \sum_i a_{i,t} m_i$ (6)
$q^*_t = [q_t \; r_t]$ (7)
[Figure 1: The Read-Process-and-Write model (Read, Process, and Write blocks).]
where $i$ indexes through each memory vector $m_i$ (typically equal to the cardinality of X), $q_t$ is a query vector which allows us to read $r_t$ from the memories, $f$ is a function that computes a single scalar from $m_i$ and $q_t$ (e.g., a dot product), and LSTM is an LSTM which computes a recurrent state but which takes no inputs. $q^*_t$ is the state which this LSTM evolves, and is formed by concatenating the query $q_t$ with the resulting attention readout $r_t$. $t$ is the index which indicates
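Equations (4)-(7) can be sketched in pure Python with $f$ taken as a dot product, one of the choices the excerpt mentions. To keep the sketch short, the input-free LSTM of Eq. (3) is replaced by a caller-supplied function `recurrent_step`. It also starts from $q^*_0 = 0$. Both are assumptions, and `toy_step` is a hypothetical stand-in, not an LSTM.

```python
import math

def attention_read(memories, q):
    """Eqs. (4)-(6): e_i = <m_i, q>, a = softmax(e), r = sum_i a_i m_i."""
    e = [sum(mk * qk for mk, qk in zip(m, q)) for m in memories]
    top = max(e)
    w = [math.exp(x - top) for x in e]
    total = sum(w)
    a = [x / total for x in w]
    return [sum(ai * m[k] for ai, m in zip(a, memories)) for k in range(len(q))]

def process_block(memories, steps, recurrent_step):
    """Repeat: q_t = step(q*_{t-1}); r_t = read(q_t); q*_t = [q_t, r_t]."""
    d = len(memories[0])
    q_star = [0.0] * (2 * d)               # q*_0 = 0 (assumption)
    for _ in range(steps):
        q = recurrent_step(q_star)         # stand-in for Eq. (3)
        r = attention_read(memories, q)    # Eqs. (4)-(6)
        q_star = q + r                     # Eq. (7): list concatenation
    return q_star

def toy_step(q_star):
    """Hypothetical input-free recurrence mapping a 2d-vector to a d-vector."""
    d = len(q_star) // 2
    return [math.tanh(a + b + 0.5) for a, b in zip(q_star[:d], q_star[d:])]

memories = [[1.0, 0.0], [0.0, 2.0], [-1.0, 1.0]]
out = process_block(memories, steps=3, recurrent_step=toy_step)
print(out)
```

Shuffling `memories` leaves `out` unchanged, up to floating-point rounding. This is the order invariance the excerpt calls crucial for treating X as a set.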
47. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Figure 2: Illustration of SchNet with an architectural overview (left), the interaction block (middle) and the continuous-filter convolution with filter-generating network (right). The shifted softplus is defined as $\mathrm{ssp}(x) = \ln(0.5e^x + 0.5)$.
Source: https://www.slideshare.net/KazukiFujikawa/schnet-a-continuousfilter-convolutional-neural-network-for-modeling-quantum-interactions
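The shifted softplus from the caption can be computed directly. Since $\ln(0.5e^x + 0.5) = \ln(1 + e^x) - \ln 2$, a minimal sketch can write it as a numerically stable softplus minus $\ln 2$. This avoids overflow of $e^x$ for large $x$, and it makes it easy to check that $\mathrm{ssp}(0) = 0$.

```python
import math

def shifted_softplus(x: float) -> float:
    """ssp(x) = ln(0.5 * e^x + 0.5) = softplus(x) - ln 2."""
    # Stable softplus: ln(1 + e^x) = max(x, 0) + ln(1 + e^{-|x|})
    softplus = max(x, 0.0) + math.log1p(math.exp(-abs(x)))
    return softplus - math.log(2.0)

for x in (-5.0, 0.0, 1.0, 1000.0):
    print(x, shifted_softplus(x))
```

For large $x$ the function approaches $x - \ln 2$, and for very negative $x$ it approaches $-\ln 2$. Unlike ReLU, it is smooth everywhere.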
48. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
[Figure: SchNet overview (Figure 2), annotated with the atom inputs $Z_i$ and the relative positions $r_i - r_j$ that feed the filter-generating network; below it, Figure 3 panels (a) 1st interaction block and (b) 2nd interaction block.]
Figure 3: 10x10 Å cuts through all 64 radial, three-dimensional filte
SchNet trained on molecular dynamics of ethanol. Negative values ar
Filter-generating networks The cfconv layer including its filter-ge
at the right panel of Fig. 2. In order to satisfy the requirements for
we restrict our filters for the cfconv layers to be rotationally invarian
obtained by using interatomic distances
$d_{ij} = \lVert r_i - r_j \rVert$
as input for the filter network. Without further processing, the filters w
a neural network after initialization is close to linear. This leads to
training that is hard to overcome. We avoid this by expanding the dista
$e_k(r_i - r_j) = \exp(-\gamma \lVert d_{ij} - \mu_k \rVert^2)$
located at centers $0\,\text{Å} \le \mu_k \le 30\,\text{Å}$ every 0.1 Å with $\gamma = 10\,\text{Å}$. This is
occurring in the data sets are covered by the filters. Due to this addit
filters are less correlated leading to a faster training procedure. Choos
to reducing the resolution of the filter, while restricting the range of
filter size in a usual convolutional layer. An extensive evaluation of th
left for future work. We feed the expanded distances into two dense l
to compute the filter weight $W(r_i - r_j)$ as shown in Fig. 2 (right).
Fig 3 shows 2d-cuts through generated filters for all three interaction
an ethanol molecular dynamics trajectory. We observe how each filter
interatomic distances. This enables its interaction block to update the re
radial environment of each atom. The sequential updates from three in
to construct highly complex many-body representations in the spirit o
rotational invariance due to the radial filters.
4.2 Training with energies and forces
As described above, the interatomic forces are related to the molecula
an energy-conserving force model by differentiating the energy mode
$\hat{F}_i(Z_1, \ldots, Z_n, r_1, \ldots, r_n) = -\frac{\partial \hat{E}}{\partial r_i}(Z_1, \ldots, Z_n, r_1, \ldots, r_n)$
49. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
[Figure: the same SchNet overview, annotated with $Z_i$ and $r_i - r_j$; Figure 3 panels (a) 1st interaction block and (b) 2nd interaction block.]
• Radial basis centers: $\mu_1 = 0.1\,\text{Å}, \mu_2 = 0.2\,\text{Å}, \ldots, \mu_{300} = 30\,\text{Å}$; width parameter $\gamma = 10\,\text{Å}$
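The distance expansion $e_k = \exp(-\gamma \lVert d_{ij} - \mu_k \rVert^2)$ with the 300 centers listed above can be sketched as follows. The sketch assumes distances are given in Å and treats $\gamma$ as the plain number 10 applied to squared Å differences. The function name is hypothetical.

```python
import math

def rbf_expand(d, gamma=10.0, step=0.1, n_centers=300):
    """e_k(d) = exp(-gamma * (d - mu_k)^2), mu_k = step * k for k = 1..n_centers."""
    return [math.exp(-gamma * (d - step * k) ** 2) for k in range(1, n_centers + 1)]

feats = rbf_expand(1.5)  # an interatomic distance of 1.5 Å
print(len(feats), feats.index(max(feats)))  # 300 14
```

The scalar distance becomes a 300-dimensional feature that peaks at the nearest center ($\mu_{15} = 1.5$ Å, index 14) and decays smoothly on either side. Fewer centers would lower the resolution, as the excerpt notes.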
50. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
[Figure: SchNet overview annotated with $Z_i$ and $r_i - r_j$; Figure 3 panels (a) and (b).]
51. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
52. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
(a) 1st interaction block, (b) 2nd interaction block, (c) 3rd interaction block
Figure 3: 10x10 Å cuts through all 64 radial, three-dimensional filters in each interaction block of
SchNet trained on molecular dynamics of ethanol. Negative values are blue, positive values are red.
Filter-generating networks The cfconv layer including its filter-generating network are depicted
at the right panel of Fig. 2. In order to satisfy the requirements for modeling molecular energies,
53. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
4.1 Architecture
54. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
We include the total energy $E$ as well as forces $F_i$ in the training loss to train a neural network that performs well on both properties:
$\ell\left(\hat{E}, (E, F_1, \ldots, F_n)\right) = \lVert E - \hat{E} \rVert^2 + \frac{\rho}{n} \sum_{i=0}^{n} \left\lVert F_i - \left(-\frac{\partial \hat{E}}{\partial R_i}\right) \right\rVert^2.$ (5)
This kind of loss has been used before for fitting restricted potential energy surfaces with MLPs [36].
In our experiments, we use $\rho = 0$ in Eq. 5 for pure energy based training and $\rho = 100$ for combined energy and force training. The value of $\rho$ was optimized empirically to account for different scales of energy and forces.
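Eq. (5) can be sketched as a plain function of the reference and predicted values. In this sketch the predicted forces $-\partial \hat{E} / \partial R_i$ are passed in already computed, and the force term sums over the $n$ atoms. The function name is hypothetical.

```python
def energy_force_loss(E_true, E_pred, F_true, F_pred, rho):
    """Eq. (5): (E - E_hat)^2 + rho / n * sum_i ||F_i - F_hat_i||^2.

    F_true, F_pred: lists of n 3-vectors; F_pred[i] stands for -dE_hat/dR_i.
    """
    n = len(F_true)
    energy_term = (E_true - E_pred) ** 2
    force_term = sum(
        sum((a - b) ** 2 for a, b in zip(Fi, Fi_hat))
        for Fi, Fi_hat in zip(F_true, F_pred)
    )
    return energy_term + rho / n * force_term

F_true = [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
F_pred = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(energy_force_loss(1.0, 0.5, F_true, F_pred, rho=0.0))    # 0.25 (energy only)
print(energy_force_loss(1.0, 0.5, F_true, F_pred, rho=100.0))  # 50.25
```

With $\rho = 100$, a unit force error outweighs the 0.5 energy error by a factor of 200. This illustrates why $\rho$ has to be tuned to the relative scales of energies and forces.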
Due to the relation of energies and forces reflected in the model, we expect to see improved generalization, however, at a computational cost. As we need to perform a full forward and backward pass on the energy model to obtain the forces, the resulting force model is twice as deep and, hence, requires about twice the amount of computation time.
Even though the GDML model captures this relationship between energies and forces, it is explicitly optimized to predict the force field while the energy prediction is a by-product. Models such as circular fingerprints [15], molecular graph convolutions or message-passing neural networks [19] for property prediction across chemical compound space are only concerned with equilibrium molecules, i.e., the special case where the forces are vanishing. They cannot be trained with forces in a similar manner, as they include discontinuities in their predicted potential energy surface caused by discrete binning or the use of one-hot encoded bond type information.
5 Experiments and results
$e_k(r_i - r_j) = \exp(-\gamma \lVert d_{ij} - \mu_k \rVert^2)$
located at centers $0\,\text{Å} \le \mu_k \le 30\,\text{Å}$ every 0.1 Å with $\gamma = 10\,\text{Å}$. This is chosen such that all distances
occurring in the data sets are covered by the filters. Due to this additional non-linearity, the initial
filters are less correlated leading to a faster training procedure. Choosing fewer centers corresponds
to reducing the resolution of the filter, while restricting the range of the centers corresponds to the
filter size in a usual convolutional layer. An extensive evaluation of the impact of these variables is
left for future work. We feed the expanded distances into two dense layers with softplus activations
to compute the filter weight $W(r_i - r_j)$ as shown in Fig. 2 (right).
Fig 3 shows 2d-cuts through generated filters for all three interaction blocks of SchNet trained on
an ethanol molecular dynamics trajectory. We observe how each filter emphasizes certain ranges of
interatomic distances. This enables its interaction block to update the representations according to the
radial environment of each atom. The sequential updates from three interaction blocks allow SchNet
to construct highly complex many-body representations in the spirit of DTNNs [18] while keeping
rotational invariance due to the radial filters.
4.2 Training with energies and forces
As described above, the interatomic forces are related to the molecular energy, so that we can obtain
an energy-conserving force model by differentiating the energy model w.r.t. the atom positions
$\hat{F}_i(Z_1, \ldots, Z_n, r_1, \ldots, r_n) = -\frac{\partial \hat{E}}{\partial r_i}(Z_1, \ldots, Z_n, r_1, \ldots, r_n).$ (4)
Chmiela et al. [17] pointed out that this leads to an energy-conserving force-field by construction.
As SchNet yields rotationally invariant energy predictions, the force predictions are rotationally
equivariant by construction. The model has to be at least twice differentiable to allow for gradient
descent of the force loss. We chose a shifted softplus $\mathrm{ssp}(x) = \ln(0.5e^x + 0.5)$ as non-linearity
throughout the network in order to obtain a smooth potential energy surface. The shifting ensures that
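Eq. (4) obtains forces as the negative gradient of a rotation-invariant energy. The sketch below illustrates the two properties stated above, energy conservation by construction and rotational equivariance. It uses a hypothetical pairwise spring energy in place of SchNet. To stay dependency-free, it estimates $-\partial E / \partial r_i$ with central finite differences; SchNet instead backpropagates through the network, which is the extra backward pass mentioned earlier.

```python
import math

def toy_energy(pos, d0=1.0):
    """Hypothetical rotation-invariant energy: springs between all atom pairs."""
    E = 0.0
    for i in range(len(pos)):
        for j in range(i + 1, len(pos)):
            E += (math.dist(pos[i], pos[j]) - d0) ** 2
    return E

def forces(energy, pos, h=1e-5):
    """F_i = -dE/dr_i, estimated with central finite differences."""
    F = []
    for i in range(len(pos)):
        Fi = []
        for k in range(3):
            plus = [list(p) for p in pos]
            minus = [list(p) for p in pos]
            plus[i][k] += h
            minus[i][k] -= h
            Fi.append(-(energy(plus) - energy(minus)) / (2 * h))
        F.append(Fi)
    return F

def rotate_z(v, theta):
    """Rotate a 3-vector by theta radians about the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1], v[2]]

pos = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 2.0, 0.0]]
F = forces(toy_energy, pos)
F_rot = forces(toy_energy, [rotate_z(p, 0.3) for p in pos])
print(F)
print(F_rot)  # matches [rotate_z(f, 0.3) for f in F] up to finite-difference error
```

Rotating the atoms rotates the forces in the same way, and the forces sum to zero. Neither property is enforced explicitly; both follow from differentiating an energy that depends only on interatomic distances.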
55. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Table 1: Mean absolute errors for energy predictions in kcal/mol on the QM9 data set with given
training set size N. Best model in bold.
N SchNet DTNN [18] enn-s2s [19] enn-s2s-ens5 [19]
50,000 0.59 0.94 – –
100,000 0.34 0.84 – –
110,462 0.31 – 0.45 0.33
56. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
[Figure 2: pairwise classification architecture. The ligand protein graph and the receptor protein graph each pass through a stack of graph convolution layers; the resulting residue representations (R1, R2, R3) are merged into residue-pair representations, followed by fully-connected layers and classification.]
Figure 2: An overview of the pairwise classification architecture. Each neighborhood of a residue i
proteins is processed using one or more graph convolution layers, with weight sharing between le
network. The activations generated by the convolutional layers are merged by concatenating them, fol
[Figure 1 (left) labels: node features Residue Conservation / Composition, Accessible Surface Area, Residue Depth, Protrusion Index; edge features Distance, Angle.]
Figure 1: Graph convolution on protein structures. Left: Each residue in a protein is a node in a graph where the neighborhood of a node is the set of neighboring nodes in the protein structure; each node has features computed from its amino acid sequence and structure, and edges have features describing the relative distance and angle between residues. Right: Schematic description of the convolution operator which has as its receptive field a set of neighboring residues, and produces an activation which is associated with the center residue.
57. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
of the neighbors of a node. Our objective is to design convolution operators that can be applied to graphs without a regular structure, and without imposing a particular order on the neighbors of a given node. To summarize, we would like to learn a mapping at each node in the graph which has the form: $z_i = \sigma_W(x_i, \{x_{n_1}, \ldots, x_{n_k}\})$, where $\{n_1, \ldots, n_k\}$ are the neighbors of node $i$ that define the receptive field of the convolution, $\sigma$ is a non-linear activation function, and $W$ are its learned parameters; the dependence on the neighboring nodes as a set represents our intention to learn a function that is order-independent. We present the following two realizations of this operator that provide the output of a set of filters in a neighborhood of a node of interest that we refer to as the "center node":
$z_i = \sigma\left(W^C x_i + \frac{1}{|N_i|} \sum_{j \in N_i} W^N x_j + b\right),$ (1)
where $N_i$ is the set of neighbors of node $i$, $W^C$ is the weight matrix associated with the center node, $W^N$ is the weight matrix associated with neighboring nodes, and $b$ is a vector of biases, one for each filter. The dimensionality of the weight matrices is determined by the dimensionality of the inputs and the number of filters. The computational complexity of this operator on a graph with $n$ nodes, a
neighborhood of size $k$, $F_{in}$ input features and $F_{out}$ output features is $O(k F_{in} F_{out} n)$. Construction of the neighborhood is straightforward using a preprocessing step that takes $O(n^2 \log n)$.
In order to provide for some differentiation between neighbors, we incorporate features on the edges between each neighbor and the center node as follows:
$z_i = \sigma\left(W^C x_i + \frac{1}{|N_i|} \sum_{j \in N_i} W^N x_j + \frac{1}{|N_i|} \sum_{j \in N_i} W^E A_{ij} + b\right),$
where $W^E$ is the weight matrix associated with edge features.
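The order-independent operator of Eq. (1) and its edge-feature variant can be sketched in a single function. The sketch assumes ReLU for $\sigma$ (the text only requires a non-linear activation). Passing `We=None` gives Eq. (1), and the names are hypothetical. Each neighbor contributes through the same shared matrices, averaged over $|N_i|$, so the neighbor order does not matter.

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def relu(v):
    return [max(0.0, a) for a in v]

def graph_conv(i, x, nbrs, Wc, Wn, b, edge_feats=None, We=None, act=relu):
    """z_i = act(Wc x_i + 1/|N_i| sum_j Wn x_j [+ 1/|N_i| sum_j We A_ij] + b)."""
    Ni = nbrs[i]
    z = [a + c for a, c in zip(matvec(Wc, x[i]), b)]
    for j in Ni:
        z = [a + c / len(Ni) for a, c in zip(z, matvec(Wn, x[j]))]
        if We is not None:
            z = [a + c / len(Ni) for a, c in zip(z, matvec(We, edge_feats[(i, j)]))]
    return act(z)

x = {0: [1.0], 1: [2.0], 2: [4.0]}   # one input feature per residue
nbrs = {0: [1, 2]}                   # neighborhood of the center node 0
A = {(0, 1): [1.0], (0, 2): [3.0]}   # one edge feature per (center, neighbor)
Wc, Wn, b = [[1.0]], [[1.0]], [0.0]
print(graph_conv(0, x, nbrs, Wc, Wn, b))                          # [4.0]
print(graph_conv(0, x, nbrs, Wc, Wn, b, edge_feats=A, We=[[1.0]]))  # [6.0]
```

Each call touches $k$ neighbors, each with an $F_{out} \times F_{in}$ product. Repeated over $n$ center nodes, this gives the $O(k F_{in} F_{out} n)$ cost stated above.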
For comparison with order-independent methods we propose an order-dependent method, wh
order is determined by distance from the center node. In this method each neighbor has unique we
matrices for nodes and edges:
zi =
✓
W C
xi +
1
|Ni|
X
j2Ni
W N
j xj +
1
|Ni|
X
j2Ni
W E
j Aij + b
◆
.
Here W N
j /W E
j are the weight matrices associated with the jth
node or the edges connecting to the
nodes, respectively. This operator is inspired by the PATCHY-SAN method of Niepert et al. [16].
more flexible than the order-independent convolutional operators, allowing the learning of distinct
between neighbors at the cost of significantly more parameters.
Multiple layers of these graph convolution operators can be used, and this will have the ef
of learning features that characterize the graph at increasing levels of abstraction, and will
allow information to propagate through the graph, thereby integrating information across region
increasing size. Furthermore, these operators are rotation-invariant if the features have this prop
In convolutional networks, inputs are often downsampled based on the size and stride of the recep
field. It is also common to use pooling to further reduce the size of the input. Our graph opera
on the other hand maintain the structure of the graph, which is necessary for the protein interf
prediction problem, where we classify pairs of nodes from different graphs, rather than en
graphs. Using convolutional architectures that use only convolutional layers without downsamplin
common practice in the area of graph convolutional networks, especially if classification is perform
neighborhood of size k, Fin input features and Fout output features is O(kFinFoutn). Construction of
the neighborhood is straightforward using a preprocessing step that takes O(n2
log n).
In order to provide for some differentiation between neighbors, we incorporate features on the edges
between each neighbor and the center node as follows:
$$z_i = \sigma\left( W^C x_i + \frac{1}{|N_i|} \sum_{j \in N_i} W^N x_j + \frac{1}{|N_i|} \sum_{j \in N_i} W^E A_{ij} + b \right), \tag{2}$$
where $A_{ij}$ denotes the features of the edge between node $i$ and its neighbor $j$, and $W^E$ is the weight matrix associated with edge features.
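Below is a hedged pure-Python sketch of Eq. (2), the Node and Edge Average operator. As before, ReLU for $\sigma$ is an assumption. The dictionary `edge_feats` keyed by `(i, j)` is a hypothetical container for $A_{ij}$, and `node_edge_average_conv` is an illustrative name, not the paper's code. The example uses identical node weights in both calls so that only the edge term differs between them.

```python
from typing import Dict, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape (F_out, F_in)
EdgeFeats = Dict[Tuple[int, int], Vector]  # hypothetical storage for A_ij


def matvec(w: Matrix, x: Sequence[float]) -> Vector:
    """Multiply an F_out x F_in matrix by a length-F_in vector."""
    return [sum(w_rc * x_c for w_rc, x_c in zip(row, x)) for row in w]


def node_edge_average_conv(x: Sequence[Vector], edge_feats: EdgeFeats,
                           hoods: Sequence[Sequence[int]], w_c: Matrix,
                           w_n: Matrix, w_e: Matrix, b: Vector) -> List[Vector]:
    """Eq. (2), with sigma assumed to be ReLU.

    Both averaged sums share the factor 1/|N_i|, so each neighbor j adds
    (W^N x_j + W^E A_ij) / |N_i| to the pre-activation of node i.
    """
    out = []
    for i, nbrs in enumerate(hoods):
        acc = matvec(w_c, x[i])  # center-node term W^C x_i
        for j in nbrs:
            node_term = matvec(w_n, x[j])
            edge_term = matvec(w_e, edge_feats[(i, j)])
            acc = [a + (p + q) / len(nbrs) for a, p, q in zip(acc, node_term, edge_term)]
        out.append([max(0.0, a + bb) for a, bb in zip(acc, b)])  # add bias, then ReLU
    return out


# Two nodes that are each other's only neighbor; edge features differ by direction.
x = [[1.0], [2.0]]
hoods = [[1], [0]]
edge_feats = {(0, 1): [0.5], (1, 0): [1.0]}
with_edges = node_edge_average_conv(x, edge_feats, hoods, [[1.0]], [[1.0]], [[2.0]], [0.0])
without_edges = node_edge_average_conv(x, edge_feats, hoods, [[1.0]], [[1.0]], [[0.0]], [0.0])
print(with_edges)     # [[4.0], [5.0]]
print(without_edges)  # [[3.0], [3.0]]
```

With $W^E = 0$ the operator reduces to Eq. (1), and both nodes receive the same activation. A non-zero $W^E$ lets the edge features separate them.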
For comparison with order-independent methods, we propose an order-dependent method, where the
order is determined by distance from the center node. In this method, each neighbor has unique weight
matrices for nodes and edges:
$$z_i = \sigma\left( W^C x_i + \frac{1}{|N_i|} \sum_{j \in N_i} W^N_j x_j + \frac{1}{|N_i|} \sum_{j \in N_i} W^E_j A_{ij} + b \right). \tag{3}$$
Here $W^N_j$ and $W^E_j$ are the weight matrices associated with the $j$-th node and with the edge connecting to the $j$-th
node, respectively. This operator is inspired by the PATCHY-SAN method of Niepert et al. [16]. It is
more flexible than the order-independent convolutional operators, allowing the learning of distinctions
between neighbors at the cost of significantly more parameters.
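The order-dependent variant of Eq. (3) can be sketched as follows. Here the index $j$ of $W^N_j$ and $W^E_j$ is read as the neighbor's rank by distance, with 0 meaning closest. Each node's neighbor list is assumed to be pre-sorted closest-first. ReLU, the name `order_dependent_conv`, and the `(i, j)`-keyed edge dictionary are all illustrative assumptions. The example reverses one node's neighbor order to show that the output changes.

```python
from typing import Dict, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape (F_out, F_in)


def matvec(w: Matrix, x: Sequence[float]) -> Vector:
    """Multiply an F_out x F_in matrix by a length-F_in vector."""
    return [sum(w_rc * x_c for w_rc, x_c in zip(row, x)) for row in w]


def order_dependent_conv(x: Sequence[Vector], edge_feats: Dict[Tuple[int, int], Vector],
                         hoods: Sequence[Sequence[int]], w_c: Matrix,
                         w_n_by_rank: Sequence[Matrix], w_e_by_rank: Sequence[Matrix],
                         b: Vector) -> List[Vector]:
    """Eq. (3): the neighbor at rank r uses its own W^N_r and W^E_r.

    hoods[i] is assumed to be sorted by distance from node i, closest first,
    so a neighborhood of size k needs k node matrices and k edge matrices.
    """
    out = []
    for i, nbrs in enumerate(hoods):
        acc = matvec(w_c, x[i])  # center-node term W^C x_i
        for rank, j in enumerate(nbrs):
            node_term = matvec(w_n_by_rank[rank], x[j])
            edge_term = matvec(w_e_by_rank[rank], edge_feats[(i, j)])
            acc = [a + (p + q) / len(nbrs) for a, p, q in zip(acc, node_term, edge_term)]
        out.append([max(0.0, a + bb) for a, bb in zip(acc, b)])  # add bias, then ReLU
    return out


x = [[1.0], [2.0], [3.0]]
edge_feats = {(i, j): [0.0] for i in range(3) for j in range(3) if i != j}
w_n_by_rank = [[[1.0]], [[0.0]]]  # only the closest neighbor contributes
w_e_by_rank = [[[0.0]], [[0.0]]]
z_a = order_dependent_conv(x, edge_feats, [[1, 2], [0, 2], [1, 0]], [[0.0]],
                           w_n_by_rank, w_e_by_rank, [0.0])
z_b = order_dependent_conv(x, edge_feats, [[2, 1], [0, 2], [1, 0]], [[0.0]],
                           w_n_by_rank, w_e_by_rank, [0.0])
print(z_a[0], z_b[0])  # [1.0] [1.5]
```

The flexibility comes at a parameter cost. Where Eqs. (1) and (2) share one $W^N$ and one $W^E$ across all neighbors, this sketch stores $k$ of each.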
Multiple layers of these graph convolution operators can be used. Stacking them has the effect
of learning features that characterize the graph at increasing levels of abstraction, and also
allows information to propagate through the graph, thereby integrating information across regions of
increasing size. Furthermore, these operators are rotation-invariant if the features have this property.
In convolutional networks, inputs are often downsampled based on the size and stride of the receptive
field. It is also common to use pooling to further reduce the size of the input. Our graph operators,
on the other hand, maintain the structure of the graph, which is necessary for the protein interface
prediction problem, where we classify pairs of nodes from different graphs rather than entire
graphs. Architectures that use only convolutional layers without downsampling are
common practice in the area of graph convolutional networks, especially if classification is performed
at the node or edge level. This practice is supported by the success of networks without pooling
layers in object recognition [23]. The downside of not downsampling is higher memory
and computational cost.
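The claim that stacked layers integrate information across regions of increasing size can be checked with a small sketch. It assumes that each layer mixes a node with its own neighbors, as the center term plus neighbor term in Eqs. (1) to (3) do. Under that assumption, the set of input nodes that can influence $z_i$ grows by one neighborhood hop per layer. The function `receptive_field` and the four-node chain are hypothetical illustrations.

```python
from typing import List, Sequence, Set


def receptive_field(hoods: Sequence[Sequence[int]], i: int, num_layers: int) -> Set[int]:
    """Input nodes whose features can reach z_i after num_layers stacked layers.

    The center term keeps each node in its own field; the neighbor term adds
    that node's neighborhood, so the field expands one hop per layer.
    """
    field = {i}
    for _ in range(num_layers):
        grown = set(field)
        for node in field:
            grown.update(hoods[node])
        field = grown
    return field


# A chain of four nodes: 0 - 1 - 2 - 3
chain: List[List[int]] = [[1], [0, 2], [1, 3], [2]]
for layers in range(4):
    print(layers, sorted(receptive_field(chain, 0, layers)))
```

Because no layer pools or downsamples, every node keeps its own output at every depth. Only the region it summarizes grows.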
Related work. Several authors have recently proposed graph convolutional operators that generalize
Method                                  Convolutional layers
                                        1              2              3              4
No Convolution                          0.812 (0.007)  0.810 (0.006)  0.808 (0.006)  0.796 (0.006)
Diffusion (DCNN) (2 hops) [5]           0.790 (0.014)  –              –              –
Diffusion (DCNN) (5 hops) [5]           0.828 (0.018)  –              –              –
Single Weight Matrix (MFN [9])          0.865 (0.007)  0.871 (0.013)  0.873 (0.017)  0.869 (0.017)
Node Average (Equation (1))             0.864 (0.007)  0.882 (0.007)  0.891 (0.005)  0.889 (0.005)
Node and Edge Average (Equation (2))    0.876 (0.005)  0.898 (0.005)  0.895 (0.006)  0.889 (0.007)
DTNN [21]                               0.867 (0.007)  0.880 (0.007)  0.882 (0.008)  0.873 (0.012)
Order Dependent (Equation (3))          0.854 (0.004)  0.873 (0.005)  0.891 (0.004)  0.889 (0.008)
Table 2: Median area under the receiver operating characteristic curve (AUC) across all complexes in the
test set for various graph convolutional methods. Results shown are the average and standard deviation over
ten runs with different random seeds. Networks have the following number of filters for 1, 2, 3, and 4 layers
before merging, respectively: (256), (256, 512), (256, 256, 512), (256, 256, 512, 512). The exception is the
DTNN method, which by necessity produces an output with the same dimensionality as its input. Unlike
the other methods, diffusion convolution performed best with an RBF with a standard deviation of 2 Å. After
merging, all networks have a dense layer with 512 hidden units followed by a binary classification layer. Boldface
values indicate the best performance for each method.
58. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Figure 1: Scene graphs are defined by the objects in an image (vertices) and their interactions (edges).
The ability to express information about the connections between objects makes scene graphs a useful
representation for many computer vision tasks, including captioning and visual question answering.
59. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Figure 2: Full pipeline for object and relationship detection. A network is trained to produce two
heatmaps that activate at the predicted locations of objects and relationships. Feature vectors are
extracted from the pixel locations of top activations and fed through fully connected networks to
60. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
61. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
not indicate which vertex serves as the source and which serves as the destination, nor does it
disambiguate between pairs of vertices that happen to share the same midpoint.
To train the network to produce a coherent set of embeddings, we build on the loss penalty used in
[20]. During training, we have a ground-truth set of annotations defining the unique objects in the
scene and the edges between these objects. This allows us to enforce two penalties: that an edge points
to a vertex by matching its output embedding as closely as possible, and that the embedding vectors
produced for each vertex are sufficiently different. We think of the first as “pulling together” all
references to a single vertex, and the second as “pushing apart” the references to different individual
vertices.
We consider an embedding $h_i \in \mathbb{R}^d$ produced for a vertex $v_i \in V$. All edges that connect to this
vertex produce a set of embeddings $h'_{ik}$, $k = 1, \dots, K_i$, where $K_i$ is the total number of references to
that vertex. Given an image with $n$ objects, the loss to “pull together” these embeddings is:

$$L_{pull} = \frac{1}{\sum_{i=1}^{n} K_i} \sum_{i=1}^{n} \sum_{k=1}^{K_i} \left(h_i - h'_{ik}\right)^2$$
To “push apart” embeddings across different vertices, we first used the penalty described in [20],
but experienced difficulty with convergence. We tested alternatives, and the most reliable loss was a
margin-based penalty similar to [9]:

$$L_{push} = \sum_{i=1}^{n-1} \sum_{j=i+1}^{n} \max\left(0,\; m - \lVert h_i - h_j \rVert\right)$$
Intuitively, $L_{push}$ is highest when $h_i$ and $h_j$ are closest to each other. The penalty decreases linearly
as the distance between $h_i$ and $h_j$ grows, reaching zero once the distance exceeds the
margin $m$. Conversely, for an edge connected to a vertex $v_i$, the loss $L_{pull}$ grows quadratically
the further its reference embedding $h'_{ik}$ is from $h_i$.
The two penalties are weighted equally, giving a final associative embedding loss of $L_{pull} + L_{push}$.
In this work, we use $m = 8$ and $d = 8$. Convergence of the network improves greatly after increasing
the dimension $d$ of the tags from 1, as used in [20].
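The two penalties can be sketched in pure Python as follows, with $m = 8$ taken from the text. One interpretation is assumed: $(h_i - h'_{ik})^2$ is read as the squared Euclidean distance between the two $d$-dimensional vectors. The function names are hypothetical. The toy example uses $d = 2$ for readability rather than the $d = 8$ used in the paper.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def pull_loss(vertex_emb: Sequence[Vector], edge_refs: Sequence[Sequence[Vector]]) -> float:
    """L_pull: squared distances between each h_i and every reference h'_ik to it,
    divided by the total number of references sum_i K_i.

    (h_i - h'_ik)^2 is read here as a squared Euclidean distance (an assumption).
    """
    total_refs = sum(len(refs) for refs in edge_refs)  # sum_i K_i
    if total_refs == 0:
        return 0.0
    sq = 0.0
    for h, refs in zip(vertex_emb, edge_refs):
        for h_ref in refs:
            sq += sum((a - c) ** 2 for a, c in zip(h, h_ref))
    return sq / total_refs


def push_loss(vertex_emb: Sequence[Vector], margin: float) -> float:
    """L_push: hinge max(0, m - ||h_i - h_j||) summed over all pairs i < j."""
    n = len(vertex_emb)
    total = 0.0
    for i in range(n - 1):
        for j in range(i + 1, n):
            total += max(0.0, margin - math.dist(vertex_emb[i], vertex_emb[j]))
    return total


def associative_embedding_loss(vertex_emb, edge_refs, margin: float = 8.0) -> float:
    """Equal weighting of the two penalties: L_pull + L_push."""
    return pull_loss(vertex_emb, edge_refs) + push_loss(vertex_emb, margin)


# Two vertices 5 apart (inside the margin), with three edge references in total.
vertices = [[0.0, 0.0], [3.0, 4.0]]
refs = [[[0.0, 1.0]], [[3.0, 4.0], [3.0, 2.0]]]
print(pull_loss(vertices, refs))            # (1 + 0 + 4) / 3
print(push_loss(vertices, margin=8.0))      # 8 - 5 = 3.0
print(associative_embedding_loss(vertices, refs))
```

Once a pair of vertex embeddings is farther apart than $m$, that pair contributes nothing to $L_{push}$. Every reference that drifts from its vertex, however, keeps adding to $L_{pull}$.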
Once the network is trained with this loss, full construction of the graph can be performed with a
trivial postprocessing step. The network produces a pool of vertex and edge detections. For every
edge, we look at the source and destination embeddings and match each to the closest embedding
among the detected vertices. Multiple edges may have the same source and target vertices, $v_s$ and
$v_t$, and it is also possible for $v_s$ to equal $v_t$.
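The matching step can be sketched as a nearest-neighbor lookup. Euclidean distance is assumed as the similarity measure, and `build_graph` and its input layout (one pair of source and destination embeddings per detected edge) are hypothetical. The example includes a repeated vertex pair and a self-loop, the two cases the text allows.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def nearest_vertex(emb: Vector, vertex_emb: Sequence[Vector]) -> int:
    """Index of the detected vertex whose embedding is closest to emb."""
    return min(range(len(vertex_emb)), key=lambda v: math.dist(emb, vertex_emb[v]))


def build_graph(vertex_emb: Sequence[Vector],
                edge_dets: Sequence[Tuple[Vector, Vector]]) -> List[Tuple[int, int]]:
    """Map each detected edge's (source, destination) embeddings to vertex indices.

    Duplicates and self-loops are kept, since several edges may share v_s and v_t
    and v_s may equal v_t.
    """
    return [(nearest_vertex(src, vertex_emb), nearest_vertex(dst, vertex_emb))
            for src, dst in edge_dets]


vertices = [[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]]
edges = [
    ([1.0, 1.0], [9.0, 1.0]),    # vertex 0 -> vertex 1
    ([0.5, 0.5], [9.5, -0.5]),   # a second edge between the same pair
    ([0.0, 9.0], [0.5, 11.0]),   # source and destination both match vertex 2
]
print(build_graph(vertices, edges))  # [(0, 1), (0, 1), (2, 2)]
```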
62. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
                 SGGen (no RPN)   SGGen (w/ RPN)   SGCls            PredCls
                 R@50   R@100     R@50   R@100     R@50   R@100     R@50   R@100
Lu et al. [18]   –      –         0.3    0.5       11.8   14.1      27.9   35.0
Xu et al. [26]   –      –         3.4    4.2       21.7   24.4      44.8   53.0
Our model        6.7    7.8       9.7    11.3      26.5   30.0      68.0   75.2
Table 1: Results on Visual Genome
Figure 3: Predictions on Visual Genome. In the top row, the network must produce all object and
relationship detections directly from the image. The second row includes examples from an easier
version of the task where object detections are provided. Relationships outlined in green correspond
to predictions that correctly matched to a ground truth annotation.
the course of training. We set $s_o = 3$ and $s_r = 6$, which is sufficient to completely accommodate the
detection annotations for all but a small fraction of cases.
63. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Few-shot Learning
• Santoro, Adam, et al. "Meta-learning with memory-augmented neural networks."
• Vinyals, Oriol, et al. "Matching networks for one shot learning." Advances in Neural Information Processing Systems. 2016.
• Snell, Jake, Kevin Swersky, and Richard Zemel. "Prototypical networks for few-shot learning." Advances in Neural Information Processing Systems. 2017.
• Ravi, Sachin, and Hugo Larochelle. "Optimization as a model for few-shot learning."
• Finn, Chelsea, Pieter Abbeel, and Sergey Levine. "Model-agnostic meta-learning for fast adaptation of deep networks." arXiv preprint arXiv:1703.03400 (2017).
• Duan, Yan, et al. "One-shot imitation learning." Advances in Neural Information Processing Systems. 2017.
• Finn, Chelsea, et al. "One-shot visual imitation learning via meta-learning." arXiv preprint arXiv:1709.04905 (2017).
64. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Graph Convolution
• Gilmer, Justin, et al. "Neural message passing for quantum chemistry." arXiv preprint arXiv:1704.01212 (2017).
• Duvenaud, David K., et al. "Convolutional networks on graphs for learning molecular fingerprints." Advances in Neural Information Processing Systems. 2015.
• Li, Yujia, et al. "Gated graph sequence neural networks." arXiv preprint arXiv:1511.05493 (2015).
• Schütt, Kristof T., et al. "Quantum-chemical insights from deep tensor neural networks." Nature Communications 8 (2017): 13890.
• Schütt, Kristof, et al. "SchNet: A continuous-filter convolutional neural network for modeling quantum interactions." Advances in Neural Information Processing Systems. 2017.
• Fout, Alex, et al. "Protein Interface Prediction using Graph Convolutional Networks." Advances in Neural Information Processing Systems. 2017.
• Newell, Alejandro, and Jia Deng. "Pixels to Graphs by Associative Embedding." Advances in Neural Information Processing Systems. 2017.