AI System Dept.
System Management Unit
Kazuki Fujikawa
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Learning to Learn
● What is Meta Learning / Learning to Learn?
○ Goes beyond training and testing on the same distribution.
○ The task changes between training and testing, so the model has to "learn to learn"
● Datasets
132
Lake et al,
2013, 2015
I/O
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Learning to Learn
Losses
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Few-shot learning: episodic training
n T: Meta training task / T': Meta testing task — the meta-training and meta-testing label sets are disjoint
(Figure: CIFAR-10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck) split between the two tasks; https://www.cs.toronto.edu/~kriz/cifar.html)
n At meta-test time an episode consists of a label set L', a support set S', and a query x̂; training mimics this setting:
1. Sample label set L from T
2. Sample a few images as support set S from L
3. Sample a few images as batch B from L
4. Optimize on batch B, then go to 1
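To make the sampling procedure above concrete, here is a minimal sketch in Python (NumPy only); the dataset layout and function names are illustrative assumptions, not part of the slides.

```python
import numpy as np

def sample_episode(data_by_class, n_way=5, k_shot=1, n_query=15, rng=np.random):
    """Sample one episode: a label set L, a support set S, and a query batch B."""
    # 1. Sample a label set L from the meta-training task T
    label_set = rng.choice(list(data_by_class.keys()), size=n_way, replace=False)
    support, query = [], []
    for new_label, cls in enumerate(label_set):
        images = data_by_class[cls]
        idx = rng.choice(len(images), size=k_shot + n_query, replace=False)
        # 2. A few images per class form the support set S
        support += [(images[i], new_label) for i in idx[:k_shot]]
        # 3. The remaining sampled images form the query batch B
        query += [(images[i], new_label) for i in idx[k_shot:]]
    return support, query

# 4. A meta-learner would now score its support-conditioned predictions on the
#    query batch, take an optimization step, and repeat from step 1.
```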
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Losses
Learning to Learn
Model Based
● Santoro et al. ’16
● Duan et al. ’17
● Wang et al. ‘17
● Munkhdalai & Yu ‘17
● Mishra et al. ‘17
Optimization Based
● Schmidhuber ’87, ’92
● Bengio et al. ’90, ‘92
● Hochreiter et al. ’01
● Li & Malik ‘16
● Andrychowicz et al. ’16
● Ravi & Larochelle ‘17
● Finn et al ‘17
Metric Based
● Koch ’15
● Vinyals et al. ‘16
● Snell et al. ‘17
● Shyam et al. ‘17
Trends: Learning to Learn / Meta Learning
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Model Based Meta Learning
n The learner is a model with internal state (e.g., an RNN or a network with external memory) that conditions on the examples seen so far; the learning algorithm is implicit in the model's forward computation (e.g., Santoro et al. '16, Duan et al. '17).
Oriol Vinyals, NIPS 17
Model Based Meta Learning
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Matching Networks
n Classify a query x̂ by attending over the embedded support set: the prediction ŷ = Σ_i a(x̂, x_i) y_i is a weighted sum of the support labels, with attention weights a(x̂, x_i) given by a softmax over cosine similarities between the embeddings of x̂ and each support example x_i.
Oriol Vinyals, NIPS 17
Matching Networks, Vinyals et al, NIPS 2016
Metric Based Meta Learning
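A minimal sketch of the Matching Networks prediction rule described above (attention over the support set); the embedding function, tensor shapes, and function name are placeholders for illustration, not the paper's exact architecture.

```python
import torch
import torch.nn.functional as F

def matching_net_predict(embed, support_x, support_y, query_x, n_classes):
    """Predict query labels as an attention-weighted sum of support labels."""
    s = F.normalize(embed(support_x), dim=-1)   # (N_support, d)
    q = F.normalize(embed(query_x), dim=-1)     # (N_query, d)
    attention = F.softmax(q @ s.t(), dim=-1)    # cosine-similarity attention a(x_hat, x_i)
    one_hot = F.one_hot(support_y, n_classes).float()
    return attention @ one_hot                  # (N_query, n_classes)
```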
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Metric Based Meta Learning: Prototypical Networks (Snell et al., NIPS 2017)
n Compute a prototype (the mean embedding of the support examples) for each class, and classify a query by a softmax over negative distances to the prototypes
• Simpler than Matching Networks, yet strong few-shot performance
• Also applicable to zero-shot learning by embedding class meta-data instead of support examples
2.1 Notation
In few-shot classification we are given a small support set of N labeled examples S = {(x_1, y_1), ..., (x_N, y_N)} where each x_i ∈ R^D is the D-dimensional feature vector of an example and y_i ∈ {1, ..., K} is the corresponding label. S_k denotes the set of examples labeled with class k.

2.2 Model
Prototypical Networks compute an M-dimensional representation c_k ∈ R^M, or prototype, of each class through an embedding function f_φ : R^D → R^M with learnable parameters φ. Each prototype is the mean vector of the embedded support points belonging to its class:

c_k = (1 / |S_k|) Σ_{(x_i, y_i) ∈ S_k} f_φ(x_i)    (1)

Given a distance function d : R^M × R^M → [0, +∞), Prototypical Networks produce a distribution over classes for a query point x based on a softmax over distances to the prototypes in the embedding space:

p_φ(y = k | x) = exp(−d(f_φ(x), c_k)) / Σ_{k'} exp(−d(f_φ(x), c_{k'}))    (2)

Learning proceeds by minimizing the negative log-probability J(φ) = −log p_φ(y = k | x) of the true class k via SGD. Training episodes are formed by randomly selecting a subset of classes from the training set, then choosing a subset of examples within each class to act as the support set and a subset of the remainder to serve as query points.

Zero-shot learning differs from few-shot learning in that instead of being given a support set of training points, we are given a class meta-data vector v_k for each class. These could be determined in advance, or could be learned from e.g. raw text [8]. Modifying Prototypical Networks to deal with this case is straightforward: we simply define c_k = g_ϑ(v_k) to be a separate embedding of the meta-data vector. Since the meta-data vector and query point come from different input domains, it was helpful empirically to fix the prototype embedding to unit length, while the query embedding f_φ is not constrained.

Experiments were performed on Omniglot [18] and the miniImageNet version of ILSVRC-12, with the splits proposed by Ravi and Larochelle [24], plus zero-shot classification on the Caltech UCSD bird dataset (CUB-200 2011) [34]. Omniglot is a set of 1623 handwritten characters collected from 50 alphabets.
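Equations (1) and (2) above translate almost directly into code. A minimal sketch (PyTorch); the embedding network f_phi and the episode tensors are assumed to exist:

```python
import torch
import torch.nn.functional as F

def prototypical_loss(f_phi, support_x, support_y, query_x, query_y, n_way):
    """Eq. (1): class prototypes; Eq. (2): softmax over negative squared distances."""
    z_support = f_phi(support_x)                 # (N_support, M)
    z_query = f_phi(query_x)                     # (N_query, M)
    # c_k = mean of the embedded support points belonging to class k
    prototypes = torch.stack([z_support[support_y == k].mean(0) for k in range(n_way)])
    # logits are negative squared Euclidean distances to each prototype
    logits = -torch.cdist(z_query, prototypes) ** 2
    return F.cross_entropy(logits, query_y)      # -log p_phi(y = k | x)
```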
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Prototypical Networks: results
n Matches or outperforms prior few-shot methods on Omniglot and miniImageNet, and reaches state-of-the-art zero-shot accuracy on CUB-200 (see tables below).
Table 1: Few-shot classification accuracies on Omniglot. * Uses non-standard train/test splits.

Model                        | Dist.   | Fine Tune | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot
Matching Networks [32]       | Cosine  | N         | 98.1%        | 98.9%        | 93.8%         | 98.5%
Matching Networks [32]       | Cosine  | Y         | 97.9%        | 98.7%        | 93.5%         | 98.7%
Neural Statistician [7]      | -       | N         | 98.1%        | 99.5%        | 93.2%         | 98.1%
MAML [9]*                    | -       | N         | 98.7%        | 99.9%        | 95.8%         | 98.9%
Prototypical Networks (ours) | Euclid. | N         | 98.8%        | 99.7%        | 96.0%         | 98.9%
Table 2: Few-shot classification accuracies on miniImageNet. All accuracy results are averaged over 600 test episodes and are reported with 95% confidence intervals. * Results reported by [24].

Model                        | Dist.   | Fine Tune | 5-way 1-shot  | 5-way 5-shot
Baseline Nearest Neighbors*  | Cosine  | N         | 28.86 ± 0.54% | 49.79 ± 0.79%
Matching Networks [32]*      | Cosine  | N         | 43.40 ± 0.78% | 51.09 ± 0.71%
Matching Networks FCE [32]*  | Cosine  | N         | 43.56 ± 0.84% | 55.31 ± 0.73%
Meta-Learner LSTM [24]*      | -       | N         | 43.44 ± 0.77% | 60.60 ± 0.71%
MAML [9]                     | -       | N         | 48.70 ± 1.84% | 63.15 ± 0.91%
Prototypical Networks (ours) | Euclid. | N         | 49.42 ± 0.78% | 68.20 ± 0.66%
3.2 miniImageNet Few-shot Classification
The miniImageNet dataset, originally proposed by Vinyals et al. [32], is derived from the larger
Figure 3: Comparison showing the effect of distance metric and number of classes per training
episode on 5-way classification accuracy for both Matching Networks and Prototypical Networks
on miniImageNet. The x-axis indicates configuration of the training episodes (way, distance, and
shot), and the y-axis indicates 5-way test accuracy for the corresponding shot. Error bars indicate
95% confidence intervals as computed over 600 test episodes. Note that Matching Networks and
Prototypical Networks are identical in the 1-shot case.
Table 3: Zero-shot classification accuracies on CUB-200.

Model                        | Image Features | 50-way 0-shot Acc.
ALE [1]                      | Fisher         | 26.9%
SJE [2]                      | AlexNet        | 40.3%
Sample Clustering [19]       | AlexNet        | 44.3%
SJE [2]                      | GoogLeNet      | 50.1%
DS-SJE [25]                  | GoogLeNet      | 50.4%
DA-SJE [25]                  | GoogLeNet      | 50.9%
Synthesized Classifiers [6]  | GoogLeNet      | 54.7%
Prototypical Networks (ours) | GoogLeNet      | 54.8%
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Optimization-Based Meta Learning: Meta-Learner LSTM (Ravi & Larochelle, ICLR 2017)
n Learn the optimizer itself: an LSTM meta-learner outputs the update rule for the learner's parameters
• Learner update: θ_t = f_t ⊙ θ_{t−1} − i_t ⊙ ∇_{θ_{t−1}} ℒ_t
  (plain SGD is the special case θ_t = 1 · θ_{t−1} − η · ∇_{θ_{t−1}} ℒ_t)
• This matches the LSTM cell-state update c_t = f_t ⊙ c_{t−1} + i_t ⊙ c̃_t, with c_{t−1} = θ_{t−1} and c̃_t = −∇_{θ_{t−1}} ℒ_t
• The gates f_t, i_t are computed by the meta-learner from ∇_{θ_{t−1}} ℒ_t, ℒ_t, θ_{t−1} and the previous gate values
• The initial cell state c_0 (the learner's initial parameters θ_0) is also learned
(The meta-test set covers classes not present in any of the datasets in D_meta-train; similarly, a meta-validation set is used to determine hyper-parameters.)

Consider the standard gradient-descent update θ_t = θ_{t−1} − α_t ∇_{θ_{t−1}} L_t, where θ_{t−1} are the parameters of the learner after t−1 updates, α_t is the learning rate at time t, L_t is the loss optimized by the learner for its t-th update, ∇_{θ_{t−1}} L_t is the gradient of that loss with respect to parameters θ_{t−1}, and θ_t is the updated parameters of the learner.

Our key observation that we leverage here is that this update resembles the update for the cell state in an LSTM (Hochreiter & Schmidhuber, 1997):
c_t = f_t ⊙ c_{t−1} + i_t ⊙ c̃_t,
if f_t = 1, c_{t−1} = θ_{t−1}, i_t = α_t, and c̃_t = −∇_{θ_{t−1}} L_t.

Thus, we propose training a meta-learner LSTM to learn an update rule for training a neural network. We set the cell state of the LSTM to be the parameters of the learner, or c_t = θ_t, and the candidate cell state c̃_t = −∇_{θ_{t−1}} L_t, given how valuable information about the gradient is for optimization. We define parametric forms for i_t and f_t so that the meta-learner can determine optimal values through the course of the updates.

Let us start with i_t, which corresponds to the learning rate for the updates. We let
i_t = σ( W_I · [∇_{θ_{t−1}} L_t, L_t, θ_{t−1}, i_{t−1}] + b_I ),
meaning that the learning rate is a function of the current parameter value θ_{t−1}, the current gradient ∇_{θ_{t−1}} L_t, the current loss L_t, and the previous learning rate i_{t−1}. With this information, the meta-learner should be able to finely control the learning rate so as to train the learner quickly while avoiding divergence.

As for f_t, it seems possible that the optimal choice isn't the constant 1. Intuitively, what would justify shrinking the parameters of the learner and forgetting part of its previous value would be if the learner is currently in a bad local optimum and needs a large change to escape. This would correspond to a situation where the loss is high but the gradient is close to zero. Thus, one proposal for the forget gate is to have it be a function of that information, as well as the previous value of the forget gate:
f_t = σ( W_F · [∇_{θ_{t−1}} L_t, L_t, θ_{t−1}, f_{t−1}] + b_F ).

Additionally, notice that we can also learn the initial value of the cell state c_0 for the LSTM, treating it as a parameter of the meta-learner. This corresponds to the initial weights of the learner network.
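A minimal sketch of the per-step parameter update the meta-learner computes, following the equations above (the gate networks are simplified to single linear maps over per-parameter features; this illustrates the update rule only, not the full Ravi & Larochelle training loop):

```python
import torch

def meta_lstm_step(theta_prev, grad, loss, i_prev, f_prev, W_I, b_I, W_F, b_F):
    """One learner update: theta_t = f_t * theta_{t-1} - i_t * grad(L_t)."""
    loss_vec = loss * torch.ones_like(grad)
    # per-parameter features [grad, loss, theta_{t-1}, previous gate value]
    feats_i = torch.stack([grad, loss_vec, theta_prev, i_prev], dim=-1)
    feats_f = torch.stack([grad, loss_vec, theta_prev, f_prev], dim=-1)
    i_t = torch.sigmoid(feats_i @ W_I + b_I).squeeze(-1)   # learning-rate gate
    f_t = torch.sigmoid(feats_f @ W_F + b_F).squeeze(-1)   # forget gate (usually ~1)
    theta_t = f_t * theta_prev - i_t * grad
    return theta_t, i_t, f_t

# Example shapes: theta_prev, grad, i_prev, f_prev are flat parameter vectors (n,),
# loss is a scalar tensor, W_I and W_F are (4, 1), b_I and b_F are (1,).
```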
Oriol Vinyals, NIPS 17
Figure Credit: Hugo Larochelle
Optimization Based Meta Learning
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Model-Agnostic Meta-Learning (MAML), Finn et al., ICML 2017
n Learn an initialization θ such that a few gradient steps on a new task's data yield good performance; the meta-objective is the post-adaptation loss, optimized by gradient descent through the inner update.
Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (excerpt)

Algorithm 1: Model-Agnostic Meta-Learning
Require: p(T): distribution over tasks
Require: α, β: step size hyperparameters
1: randomly initialize θ
2: while not done do
3:   Sample batch of tasks T_i ∼ p(T)
4:   for all T_i do
5:     Evaluate ∇_θ L_{T_i}(f_θ) with respect to K examples
6:     Compute adapted parameters with gradient descent: θ'_i = θ − α ∇_θ L_{T_i}(f_θ)
7:   end for
8:   Update θ ← θ − β ∇_θ Σ_{T_i ∼ p(T)} L_{T_i}(f_{θ'_i})
9: end while

The meta-update differentiates through the inner gradient step (Hessian-vector products), which is supported by standard deep learning libraries; the paper also discusses a first-order approximation that drops this additional backward pass.

Figure 1: Diagram of the model-agnostic meta-learning algorithm (MAML), which optimizes for a representation θ that can quickly adapt to new tasks: from the meta-learned θ, each task's optimum θ*_1, θ*_2, θ*_3 is reachable by a few gradient steps along ∇L_1, ∇L_2, ∇L_3.
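A minimal sketch of the meta-update in Algorithm 1 (PyTorch, functional style); the model, loss function, and the `params=` interface for evaluating adapted parameters are illustrative placeholders, not MAML's reference implementation:

```python
import torch

def maml_outer_step(model, loss_fn, tasks, alpha=0.01, beta=0.001):
    """One MAML meta-update over a batch of tasks (cf. Algorithm 1)."""
    meta_loss = 0.0
    for support, query in tasks:                       # T_i ~ p(T)
        # inner step on K support examples: theta'_i = theta - alpha * grad
        inner_loss = loss_fn(model, support)
        grads = torch.autograd.grad(inner_loss, model.parameters(), create_graph=True)
        adapted = [p - alpha * g for p, g in zip(model.parameters(), grads)]
        # outer objective: loss of the adapted parameters on the task's query set
        meta_loss = meta_loss + loss_fn(model, query, params=adapted)
    meta_grads = torch.autograd.grad(meta_loss, model.parameters())
    with torch.no_grad():                              # theta <- theta - beta * meta-grad
        for p, g in zip(model.parameters(), meta_grads):
            p -= beta * g
```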
Oriol Vinyals, NIPS 17
Summing Up
Model Based
Metric Based
Optimization Based
Oriol Vinyals, NIPS 17
Examples of Optimization Based Meta Learning
Finn et al, 17
Ravi et al, 17
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
One-Shot Imitation Learning (Duan et al., NIPS 2017)
n Learn a policy that, conditioned on a single demonstration of a new (block-stacking) task, performs that task from new initial configurations.
(Figure: architecture overview — Demonstration Network (Temporal Dropout, Temporal Convolution, Neighborhood Attention), Context Network (attention over the demonstration and over the current state, blocks A–J), Manipulation Network producing the action.)
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Architecture: three modules
• Demonstration Network: embeds the (downsampled) demonstration using temporal dropout, dilated temporal convolution, and neighborhood attention
• Context Network: attends over the demonstration embedding and over the blocks in the current state, producing a context embedding
• Manipulation Network: maps the context embedding to the next action
One-Shot Imitation Learning (excerpt):
Figure 4: Illustration of the neighborhood attention operation. It receives a list of embeddings for each block and performs one attention per block; the i-th attention is applied to the i-th block vs. the others.
After downsampling the demonstration, a sequence of operations composed of dilated temporal convolution (Yu & Koltun, 2016) and neighborhood attention is applied.
4.2.3 Context Network: The context network is the crux of the model. It processes both the current state and the embedding produced by the demonstration network, and outputs a context embedding.
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Figure 2: Illustration of the network architecture.
In our experiments, we use p = 0.95, which reduces the length of demonstrations by a factor of 20.
During test time, we can sample multiple downsampled trajectories, use each of them to compute
downstream results, and average these results to produce an ensemble estimate. In our experience,
this consistently improves the performance of the policy.
Neighborhood Attention: After downsampling the demonstration, we apply a sequence of opera-
tions, composed of dilated temporal convolution [65] and neighborhood attention. We now describe
this second operation in more detail.
From the figure, we can observe that for the easier tasks with fewer stages, all of the different
conditioning strategies perform equally well and almost perfectly. As the difficulty (number of stages)
increases, however, conditioning on the entire demonstration starts to outperform conditioning on the
final state. One possible explanation is that when conditioned only on the final state, the policy may
struggle about which block it should stack first, a piece of information that is readily accessible from
demonstration, which not only communicates the task, but also provides valuable information to help
accomplish it.
More surprisingly, conditioning on the entire demonstration also seems to outperform conditioning
on the snapshot, which we originally expected to perform the best. We suspect that this is due
to the regularization effect introduced by temporal dropout, which effectively augments the set of
demonstrations seen by the policy during training.
Another interesting finding was that training with behavioral cloning has the same level of performance
as training with DAGGER, which suggests that the entire training procedure could work without
requiring interactive supervision. In our preliminary experiments, we found that injecting noise into
the trajectory collection process was important for behavioral cloning to work well, hence in all
experiments reported here we use noise injection. In practice, such noise can come from natural
human-induced noise through tele-operation, or by artificially injecting additional noise before
applying it on the physical robot.
5.2 Visualization
We visualize the attention mechanisms underlying the main policy architecture to have a better
understanding about how it operates. There are two kinds of attention we are mainly interested in,
one where the policy attends to different time steps in the demonstration, and the other where the
policy attends to different blocks in the current state. Fig. 4 shows some of the attention heatmaps.
(a) Attention over blocks in the current state. (b) Attention over downsampled demonstration.
Figure 4: Visualizing attentions performed by the policy during an entire execution. The task
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
per task for training, and maintain a separate set of trajectories and initial configurations to be used
for evaluation. The trajectories are collected using a hard-coded policy.
5.1 Performance Evaluation
(Figure 3: average success rate vs. number of stages, comparing policy types Demo, BC, DAGGER, Snapshot, and Final state; (a) performance on training tasks, (b) performance on test tasks.)
Figure 3: Comparison of different conditioning strategies. The darkest bar shows the performance of the
hard-coded policy, which unsurprisingly performs the best most of the time. For architectures that use temporal
dropout, we use an ensemble of 10 different downsampled demonstrations and average the action distributions.
Then for all architectures we use the greedy action for evaluation.
Fig. 3 shows the performance of various architectures. Results for training and test tasks are presented
separately, where we group tasks by the number of stages required to complete them. This is because
tasks that require more stages to complete are typically more challenging. In fact, even our scripted
policy frequently fails on the hardest tasks. We measure success rate per task by executing the greedy
policy (taking the most confident action at every time step) in 100 different configurations, each
conditioned on a different demonstration unseen during training. We report the average success rate
over all tasks within the same group.
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
One-Shot Visual Imitation Learning via Meta-Learning (Finn*, Yu*, Zhang, Abbeel, Levine, 2017)
n Meta-train with MAML across many manipulation tasks so that a new task can be learned from a single visual demonstration; the poster below also covers planning with visual foresight on autonomously collected data.
Planning Module
Planning with Visual Foresight [4,5]
User Input/Task Specification
How can we enable robots to learn vision-based
manipulation skills that generalize to new objects & goals?
One-Shot Visual Imitation Learning Planning with Visual Foresight
Can robots reuse data from other tasks to adapt to
new objects from only one visual demonstration?
Our meta-learning approach: Learn to learn many other
tasks using one demo
training sets test sets
meta-training
tasks
{
task 1
task 2
…
…
Meta-testing:
Learn new held-out
task from 1 demo
How can robots acquire general models and skills
using entirely autonomously-collected data?
Meta-Imitation Learning using MAML [1,2]
meta-training time:
- Learn from raw pixel observations 

(rather than task-specific, engineered representations)
- collect data with a diverse range of objects and environments
- reuse data from other objects & tasks when learning to
perform new task
Collect data autonomously
Predict future video for different actions [3,5]
Demo: Robot placing, tasks correspond to different objects.
- program initial motions, provide objects
- record camera images and robot actions
- no object supervision, camera calibration, human
annotation, etc.
Frederik Ebert, Chelsea Finn, Alex Lee, Sergey LevineChelsea Finn*, Tianhe Yu*, Tianhao Zhang, Pieter Abbeel, Sergey Levine
Meta-training:
Meta-training
tasks:
Standard robotics paradigm:
Brittle, hand-engineered pipeline.
RGB-D image
segment objects
estimate pose & physics of segments
optimize action using
estimated poses & physics
Our approach:
input image future predictions
actions:
[1] Finn, Abbeel, Levine. Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.
[2] Finn*, Yu*, Zhang, Abbeel, Levine. One-Shot Visual Imitation Learning via Meta-Learning.
Meta-Imitation Learning using MAML [1,2]
val demo
meta-training time:
training demo
meta-training
tasks
meta-test time:
demo of meta-test task
with held-out objects
policy architecture
shown in demo
Demo: Robot placing, tasks correspond to different objects.
-
-
-
Meta-training:
One-Shot Imitation
Learning Research
Meta-training
tasks: inp
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
n Meta-learning	loss:
n Task	loss	=	behavioral	cloning	loss:					[Pomerleau’89,Sammut’92]
Learning	a	One-Shot	Imitator	with	MAML
[Finn*,	Yu*,	Zhang,	Abbeel,	Levine,	2017] Pieter	Abbeel	-- embody.ai	/	UC	Berkeley	/	Gradescope
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
n Meta-training	targets	/	objects
Robot	Experiments:	Learning	to	Place
n Meta-testing	targets	/	objects
1,300	demonstrations	for	meta-training
Pieter	Abbeel	-- embody.ai	/	UC	Berkeley	/	Gradescope[Finn*,	Yu*,	Zhang,	Abbeel,	Levine,	2017]
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Robot	Experiments:	Learning	to	Place
1	demo imitation
Success rate: 90%  Pieter Abbeel -- embody.ai / UC Berkeley / Gradescope  [Finn*, Yu*, Zhang, Abbeel, Levine, 2017]
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
I/O
Why Graphs?
Natural Trend:
1. Fixed inputs / outputs
2. Tensor inputs / outputs
3. Sequential inputs / outputs
4. Graphs / Trees inputs / outputs
Trends: Graph Networks
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Message Passing Neural Networks
Trends: Graph Networks
INPUT:
● Undirected graph G
● node features h_v (vector)
● edge features e_vw (vector)
MPNN
OUTPUT:
● Graph level target
● Family of neural networks which commute with order of vertices.
● Generalizes a few existing NN based methods.
● Related to the Weisfeiler-Lehman algorithm for GI testing.
Arch
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Neural Message Passing for Quantum Chemistry (Gilmer, Schoenholz, Riley, Vinyals, Dahl, ICML 2017)
n Predict molecular properties (targets such as the energy E) directly from the molecular graph, far faster than quantum-chemical simulation:
  DFT: ~10^3 seconds per molecule  vs.  Message Passing Neural Net: ~10^-2 seconds
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
(Figure: each node v aggregates messages M_t(h_v^t, h_w^t, e_vw) from its neighbors with a sum Σ, then applies the update function U_t(h_v^t, m_v^{t+1}).)

Neural Message Passing for Quantum Chemistry (excerpt): The forward pass runs for T time steps and is defined in terms of message functions M_t and vertex update functions U_t. During the message passing phase, hidden states h_v^t at each node in the graph are updated based on messages m_v^{t+1} according to

m_v^{t+1} = Σ_{w ∈ N(v)} M_t(h_v^t, h_w^t, e_vw)        (1)
h_v^{t+1} = U_t(h_v^t, m_v^{t+1})                        (2)

where in the sum, N(v) denotes the neighbors of v in graph G. The readout phase computes a feature vector for the whole graph using some readout function R according to

ŷ = R({h_v^T | v ∈ G}).                                 (3)

The message functions M_t, vertex update functions U_t, and readout function R are all learned differentiable functions. R operates on the set of node states and must be invariant to permutations of the node states in order for the MPNN to be invariant to graph isomorphism.
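Equations (1)–(3) amount to a short message-passing loop. A minimal sketch (PyTorch), where the message, update, and readout functions are simple placeholder networks rather than any specific published variant:

```python
import torch
import torch.nn as nn

class SimpleMPNN(nn.Module):
    def __init__(self, node_dim, edge_dim, T=3):
        super().__init__()
        self.T = T
        self.message = nn.Linear(2 * node_dim + edge_dim, node_dim)   # M_t
        self.update = nn.GRUCell(node_dim, node_dim)                  # U_t
        self.readout = nn.Linear(node_dim, 1)                         # R

    def forward(self, h, edge_index, e):
        # h: (V, node_dim), edge_index: (2, E) pairs (v, w), e: (E, edge_dim)
        v, w = edge_index
        for _ in range(self.T):
            msg = torch.relu(self.message(torch.cat([h[v], h[w], e], dim=-1)))
            m = torch.zeros_like(h).index_add(0, v, msg)   # m_v = sum over neighbors w
            h = self.update(m, h)                          # h_v^{t+1} = U_t(h_v^t, m_v^{t+1})
        return self.readout(h.sum(dim=0))                  # graph-level target y_hat
```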
Convolutional Networks on Graphs for Learning Molecular Fingerprints (Duvenaud et al., 2015) — in the MPNN framework:
• Message function: M(h_v^t, h_w^t, e_vw) = concat(h_w^t, e_vw)
• Update function: U_t(h_v^t, m_v^{t+1}) = σ( H_t^{deg(v)} m_v^{t+1} )
  where H_t^{deg(v)} is a learned matrix for each time step t and vertex degree deg(v)
(Figure: node v with neighbors u_1, u_2 and initial hidden states h_v^(0), h_{u1}^(0), h_{u2}^(0).)
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Convolutional Networks on Graphs (Duvenaud et al., 2015) — readout:
• Readout function: R = f( Σ_{v,t} softmax(W_t h_v^t) )
  where f is a neural network and W_t are learned readout matrices, one per time step t
• The graph-level prediction is ŷ = R({h_v^T | v ∈ G})
(Figure: node v with neighbors u_1, u_2; hidden states h_v^(0), h_{u1}^(0), h_{u2}^(0) are propagated to h_v^(T) and aggregated into ŷ.)
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Gated Graph Neural Networks (GG-NN), Li et al., ICLR 2016 — in the MPNN framework:
• Message function: M_t(h_v^t, h_w^t, e_vw) = A_{e_vw} h_w^t
  where A_{e_vw} is a learned matrix for each edge label (e.g., bond type)
• Update function: U_t(h_v^t, m_v^{t+1}) = GRU(h_v^t, m_v^{t+1})
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
GG-NN readout (Li et al., ICLR 2016):
• Readout function: R = tanh( Σ_{v ∈ V} σ(i(h_v^T, h_v^0)) ⊙ tanh(j(h_v^T, h_v^0)) )
• i and j are neural networks; σ(i(h_v^T, h_v^0)) acts as a soft attention over nodes, and ⊙ denotes element-wise multiplication
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Deep Tensor Neural Networks (DTNN), Schütt et al., 2017 — in the MPNN framework:
• Message function: M_t(h_v^t, h_w^t, e_vw) = tanh( W^{fc} ( (W^{cf} h_w^t + b_1) ⊙ (W^{df} e_vw + b_2) ) )
  where W^{fc}, W^{cf}, W^{df} are learned matrices and b_1, b_2 are learned bias vectors
• Update function: U_t(h_v^t, m_v^{t+1}) = h_v^t + m_v^{t+1}
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
DTNN readout (Schütt et al., 2017):
• Readout function: R = Σ_v NN(h_v^T), where NN is a small neural network applied to each final node state
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
MPNN / enn-s2s (Gilmer et al., ICML 2017) — message and update functions:
• Message function ("edge network"): M_t(h_v^t, h_w^t, e_vw) = A(e_vw) h_w^t
  where A(e_vw) is a neural network that maps the edge feature vector e_vw to a d × d matrix
• Update function: U_t(h_v^t, m_v^{t+1}) = GRU(h_v^t, m_v^{t+1}), as in GG-NN, with the same update function (weight tying) used at every time step t
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
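A minimal sketch of the edge-network message function described above (PyTorch); the hidden size and the two-layer edge network are illustrative assumptions, not the paper's exact configuration:

```python
import torch
import torch.nn as nn

class EdgeNetworkMessage(nn.Module):
    """Message function M(h_v, h_w, e_vw) = A(e_vw) h_w, where A is a small NN."""
    def __init__(self, node_dim, edge_dim, hidden=64):
        super().__init__()
        self.node_dim = node_dim
        self.edge_nn = nn.Sequential(
            nn.Linear(edge_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, node_dim * node_dim))

    def forward(self, h_w, e_vw):
        # map each edge feature vector to a (d x d) matrix, then apply it to h_w
        A = self.edge_nn(e_vw).view(-1, self.node_dim, self.node_dim)
        return torch.bmm(A, h_w.unsqueeze(-1)).squeeze(-1)

# The vertex update is a GRU cell shared across time steps:
#   h_v^{t+1} = nn.GRUCell(node_dim, node_dim)(m_v^{t+1}, h_v^t)
```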
enn-s2s readout: set2set (Gilmer et al., ICML 2017)
• Readout function: R({h_v^T | v ∈ G}) = set2set({h_v^T | v ∈ G})
• set2set is the order-invariant "Read-Process-Write" aggregation from "Order Matters: Sequence to Sequence for Sets" (Vinyals et al., ICLR 2016), described in the excerpt below
size of the set, and which is order invariant. In the next sections, we explain such a modification,
which could also be seen as a special case of a Memory Network (Weston et al., 2015) or Neural
Turing Machine (Graves et al., 2014) – with a computation flow as depicted in Figure 1.
4.2 ATTENTION MECHANISMS
Neural models with memories coupled to differentiable addressing mechanism have been success-
fully applied to handwriting generation and recognition (Graves, 2012), machine translation (Bah-
danau et al., 2015a), and more general computation machines (Graves et al., 2014; Weston et al.,
2015). Since we are interested in associative memories we employed a “content” based attention.
This has the property that the vector retrieved from our memory would not change if we randomly
shuffled the memory. This is crucial for proper treatment of the input set X as such. In particular,
our process block based on an attention mechanism uses the following:
q_t = LSTM(q*_{t−1})                        (3)
e_{i,t} = f(m_i, q_t)                        (4)
a_{i,t} = exp(e_{i,t}) / Σ_j exp(e_{j,t})    (5)
r_t = Σ_i a_{i,t} m_i                        (6)
q*_t = [q_t, r_t]                            (7)

Figure 1: The Read-Process-and-Write model.
where i indexes through each memory vector mi (typically equal to the cardinality of X), qt is
a query vector which allows us to read rt from the memories, f is a function that computes a
single scalar from mi and qt (e.g., a dot product), and LSTM is an LSTM which computes a
recurrent state but which takes no inputs. q⇤
t is the state which this LSTM evolves, and is formed
by concatenating the query qt with the resulting attention readout rt. t is the index which indicates
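A minimal sketch of the Read-Process-Write attention loop in Eqs. (3)–(7) (PyTorch); a bare-bones illustration using a dot-product scoring function f, not the tuned set2set implementation used in enn-s2s:

```python
import torch
import torch.nn as nn

class Set2Set(nn.Module):
    def __init__(self, dim, steps=3):
        super().__init__()
        self.dim, self.steps = dim, steps
        self.lstm = nn.LSTMCell(2 * dim, dim)     # takes q*_{t-1} = [q, r] as input

    def forward(self, memories):                  # memories m_i: (N, dim)
        q_star = memories.new_zeros(1, 2 * self.dim)
        state = (memories.new_zeros(1, self.dim), memories.new_zeros(1, self.dim))
        for _ in range(self.steps):
            q, c = self.lstm(q_star, state)       # Eq. (3): q_t = LSTM(q*_{t-1})
            e = memories @ q.squeeze(0)           # Eq. (4): e_{i,t} = f(m_i, q_t)
            a = torch.softmax(e, dim=0)           # Eq. (5): attention weights a_{i,t}
            r = (a.unsqueeze(-1) * memories).sum(0, keepdim=True)   # Eq. (6): readout r_t
            q_star = torch.cat([q, r], dim=-1)    # Eq. (7): q*_t = [q_t, r_t]
            state = (q, c)
        return q_star.squeeze(0)                  # order-invariant set embedding
```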
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
SchNet: a continuous-filter convolutional neural network for modeling quantum interactions (Schütt et al., NIPS 2017)
n Predicts molecular energies and forces from atom types and positions, using convolutions with continuously parameterized filters over interatomic distances.
Figure 2: Illustration of SchNet with an architectural overview (left), the interaction block (middle)
and the continuous-filter convolution with filter-generating network (right). The shifted softplus is
defined as ssp(x) = ln(0.5ex
+ 0.5).
: https://www.slideshare.net/KazukiFujikawa/schnet-a-continuousfilter-convolutional-neural-network-for-modeling-quantum-
interactions
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
SchNet: continuous-filter convolution (cfconv)
n Atom-wise representations are updated in each interaction block by convolving the features of neighboring atoms with filters W(r_i − r_j) generated from the interatomic distances d_ij = ||r_i − r_j||.
(Figure: interaction blocks for atoms Z_i, Z_j, Z_j'; the cfconv layer and its filter-generating network are shown in Fig. 2, and Fig. 3 shows filter cuts for the 1st and 2nd interaction blocks.)
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
SchNet: radial basis expansion of interatomic distances
n Feeding raw distances into the filter network would yield nearly linear, highly correlated filters at initialization, so distances are first expanded in Gaussian radial basis functions:
e_k(r_i − r_j) = exp(−γ ||d_ij − μ_k||²)
n Centers μ_k are placed every 0.1 Å between 0 Å and 30 Å (μ_1 = 0.1 Å, μ_2 = 0.2 Å, ..., μ_300 = 30 Å) with γ = 10 Å, so that all distances occurring in the data sets are covered; the expanded distances are then fed into two dense layers with softplus activations to compute the filter weight W(r_i − r_j).
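A minimal sketch of the radial-basis expansion described above (NumPy); the center spacing and γ follow the values quoted from the paper, everything else is illustrative:

```python
import numpy as np

def rbf_expand(d_ij, mu_min=0.0, mu_max=30.0, spacing=0.1, gamma=10.0):
    """Expand interatomic distances d_ij (Angstrom) in Gaussian radial basis functions."""
    centers = np.arange(mu_min, mu_max + spacing, spacing)     # mu_k every 0.1 A
    d_ij = np.asarray(d_ij, dtype=float).reshape(-1, 1)
    return np.exp(-gamma * (d_ij - centers) ** 2)              # e_k(r_i - r_j)

# Example: distances for three atom pairs -> (3, 301) feature matrix
features = rbf_expand([0.96, 1.52, 2.40])
```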
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
SchNet: filter-generating network
n The filter-generating network maps the expanded distances through dense layers with (shifted) softplus activations to the filter W(r_i − r_j); because the filters depend only on interatomic distances, the resulting energy predictions are rotationally invariant.
(Figure 3: 10x10 Å cuts through all 64 radial, three-dimensional filters in each interaction block of SchNet trained on molecular dynamics of ethanol; negative values are blue, positive values are red. Each filter emphasizes certain ranges of interatomic distances, letting each interaction block update the representations according to the radial environment of each atom.)
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
SchNet: training with energies and forces
n The total energy E and the atomic forces F_i are both included in the training loss so that the network performs well on both properties:

ℓ(Ê, (E, F_1, ..., F_n)) = ||E − Ê||² + (ρ/n) Σ_{i=0}^{n} || F_i − ( −∂Ê/∂R_i ) ||²        (5)
This kind of loss has been used before for fitting restricted potential energy surfaces with MLPs [36]. In our experiments, we use ρ = 0 in Eq. 5 for pure energy based training and ρ = 100 for combined energy and force training. The value of ρ was optimized empirically to account for different scales of energy and forces.

Due to the relation of energies and forces reflected in the model, we expect to see improved generalization, however, at a computational cost. As we need to perform a full forward and backward pass on the energy model to obtain the forces, the resulting force model is twice as deep and, hence, requires about twice the amount of computation time.

Even though the GDML model captures this relationship between energies and forces, it is explicitly optimized to predict the force field while the energy prediction is a by-product. Models such as circular fingerprints [15], molecular graph convolutions or message-passing neural networks [19] for property prediction across chemical compound space are only concerned with equilibrium molecules, i.e., the special case where the forces are vanishing. They can not be trained with forces in a similar manner, as they include discontinuities in their predicted potential energy surface caused by discrete binning or the use of one-hot encoded bond type information.
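Eq. (5) combines an energy term with a force term obtained by differentiating the predicted energy with respect to the atom positions. A minimal sketch using autograd (PyTorch); the energy model is a placeholder:

```python
import torch

def energy_force_loss(energy_model, Z, R, E_true, F_true, rho=100.0):
    """||E - E_hat||^2 + (rho/n) * sum_i ||F_i - (-dE_hat/dR_i)||^2  (Eq. 5)."""
    R = R.clone().requires_grad_(True)          # atom positions (n, 3)
    E_hat = energy_model(Z, R)                  # scalar predicted energy
    # forces are the negative gradient of the energy w.r.t. positions (Eq. 4)
    F_hat = -torch.autograd.grad(E_hat, R, create_graph=True)[0]
    n = R.shape[0]
    return (E_true - E_hat) ** 2 + (rho / n) * ((F_true - F_hat) ** 2).sum()
```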
5 Experiments and results
    e_k(r_i - r_j) = \exp(-\gamma \| d_{ij} - \mu_k \|^2)

located at centers 0 Å ≤ μ_k ≤ 30 Å every 0.1 Å with γ = 10 Å. This is chosen such that all distances
occurring in the data sets are covered by the filters. Due to this additional non-linearity, the initial
filters are less correlated leading to a faster training procedure. Choosing fewer centers corresponds
to reducing the resolution of the filter, while restricting the range of the centers corresponds to the
filter size in a usual convolutional layer. An extensive evaluation of the impact of these variables is
left for future work. We feed the expanded distances into two dense layers with softplus activations
to compute the filter weight W(r_i - r_j) as shown in Fig. 2 (right).
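A minimal sketch of such a filter-generating network follows; it is an illustration under stated assumptions (layer sizes and names are assumed), not the released implementation.

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def shifted_softplus(x):
        # ssp(x) = ln(0.5 * exp(x) + 0.5) = softplus(x) - ln 2; smooth, with ssp(0) = 0
        return F.softplus(x) - math.log(2.0)

    class FilterNetwork(nn.Module):
        # Maps interatomic distances d_ij to per-feature filter weights W(r_i - r_j).
        def __init__(self, n_features=64, n_centers=301, gamma=10.0, mu_max=30.0):
            super().__init__()
            # Gaussian centers every 0.1 Å between 0 Å and 30 Å, as described in the text
            self.register_buffer("mu", torch.linspace(0.0, mu_max, n_centers))
            self.gamma = gamma
            self.dense1 = nn.Linear(n_centers, n_features)
            self.dense2 = nn.Linear(n_features, n_features)

        def forward(self, d_ij):
            # d_ij: (n_pairs,) interatomic distances
            expanded = torch.exp(-self.gamma * (d_ij.unsqueeze(-1) - self.mu) ** 2)
            h = shifted_softplus(self.dense1(expanded))
            return shifted_softplus(self.dense2(h))   # (n_pairs, n_features)

In the continuous-filter convolution the representation of atom i is then updated with an element-wise product summed over its neighbours, x_i <- sum_j x_j * W(r_i - r_j); since the filter depends only on interatomic distances, the energy prediction stays rotationally invariant.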
Fig 3 shows 2d-cuts through generated filters for all three interaction blocks of SchNet trained on
an ethanol molecular dynamics trajectory. We observe how each filter emphasizes certain ranges of
interatomic distances. This enables its interaction block to update the representations according to the
radial environment of each atom. The sequential updates from three interaction blocks allow SchNet
to construct highly complex many-body representations in the spirit of DTNNs [18] while keeping
rotational invariance due to the radial filters.
4.2 Training with energies and forces
As described above, the interatomic forces are related to the molecular energy, so that we can obtain
an energy-conserving force model by differentiating the energy model w.r.t. the atom positions
    \hat{F}_i(Z_1, \dots, Z_n, r_1, \dots, r_n) = -\frac{\partial \hat{E}}{\partial r_i}(Z_1, \dots, Z_n, r_1, \dots, r_n).   (4)
Chmiela et al. [17] pointed out that this leads to an energy-conserving force-field by construction.
As SchNet yields rotationally invariant energy predictions, the force predictions are rotationally
equivariant by construction. The model has to be at least twice differentiable to allow for gradient
descent of the force loss. We chose a shifted softplus ssp(x) = ln(0.5 e^x + 0.5) as non-linearity
throughout the network in order to obtain a smooth potential energy surface. The shifting ensures that
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
SchNet: A continuous-filter convolutional neural network for modeling quantum interactions (Schütt et al., NIPS 2017)
Table 1: Mean absolute errors for energy predictions in kcal/mol on the QM9 data set with given
training set size N. Best model in bold.
N SchNet DTNN [18] enn-s2s [19] enn-s2s-ens5 [19]
50,000 0.59 0.94 – –
100,000 0.34 0.84 – –
110,462 0.31 – 0.45 0.33
We include the total energy E as well as forces Fi in the training loss to train a neural network that
performs well on both properties:
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Protein Interface Prediction using Graph Convolutional Networks (Fout et al., NIPS 2017)
Figure 2: An overview of the pairwise classification architecture. Each neighborhood of a residue in the ligand and receptor protein graphs is processed using one or more graph convolution layers, with weight sharing between the two legs of the network. The activations (residue representations) generated by the convolutional layers are merged by concatenating them into residue-pair representations, followed by fully connected layers and a classification layer.
Figure 1: Graph convolution on protein structures. Left: Each residue in a protein is a node in a graph where the neighborhood of a node is the set of neighboring nodes in the protein structure; each node has features computed from its amino acid sequence and structure (residue conservation/composition, accessible surface area, residue depth, protrusion index), and edges have features describing the relative distance and angle between residues. Right: Schematic description of the convolution operator which has as its receptive field a set of neighboring residues, and produces an activation which is associated with the center residue.
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Protein Interface Prediction using Graph Convolutional Networks (Fout et al., NIPS 2017)
of the neighbors of a node. Our objective is to design convolution operators that can be applied to
graphs without a regular structure, and without imposing a particular order on the neighbors of a
given node. To summarize, we would like to learn a mapping at each node in the graph which has
the form: zi = W (xi, {xn1
, . . . , xnk
}), where {n1, . . . , nk} are the neighbors of node i that define
the receptive field of the convolution, is a non-linear activation function, and W are its learned
parameters; the dependence on the neighboring nodes as a set represents our intention to learn a
function that is order-independent. We present the following two realizations of this operator that
provides the output of a set of filters in a neighborhood of a node of interest that we refer to as the
"center node":
    z_i = \sigma\left( W^C x_i + \frac{1}{|N_i|} \sum_{j \in N_i} W^N x_j + b \right),   (1)

where N_i is the set of neighbors of node i, W^C is the weight matrix associated with the center node,
W^N is the weight matrix associated with neighboring nodes, and b is a vector of biases, one for each
filter. The dimensionality of the weight matrices is determined by the dimensionality of the inputs
and the number of filters. The computational complexity of this operator on a graph with n nodes, a
neighborhood of size k, F_in input features and F_out output features is O(k F_in F_out n). Construction of
the neighborhood is straightforward using a preprocessing step that takes O(n^2 log n).
In order to provide for some differentiation between neighbors, we incorporate features on the edges
between each neighbor and the center node as follows:
    z_i = \sigma\left( W^C x_i + \frac{1}{|N_i|} \sum_{j \in N_i} W^N x_j + \frac{1}{|N_i|} \sum_{j \in N_i} W^E A_{ij} + b \right),   (2)

where W^E is the weight matrix associated with edge features.
For comparison with order-independent methods we propose an order-dependent method, where
order is determined by distance from the center node. In this method each neighbor has unique weight
matrices for nodes and edges:
    z_i = \sigma\left( W^C x_i + \frac{1}{|N_i|} \sum_{j \in N_i} W^N_j x_j + \frac{1}{|N_i|} \sum_{j \in N_i} W^E_j A_{ij} + b \right).   (3)

Here W^N_j / W^E_j are the weight matrices associated with the j-th node or the edges connecting to the
j-th nodes, respectively. This operator is inspired by the PATCHY-SAN method of Niepert et al. [16]. It is
more flexible than the order-independent convolutional operators, allowing the learning of distinctions
between neighbors at the cost of significantly more parameters.
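As a concrete illustration of the order-independent operator of Eq. (2), here is a minimal sketch; it is not the authors' implementation, and the fixed-size neighborhood, tensor shapes and choice of ReLU for σ are assumptions.

    import torch
    import torch.nn as nn

    class NodeEdgeAverageConv(nn.Module):
        # Graph convolution of Eq. (2):
        # z_i = sigma(W_C x_i + mean_j W_N x_j + mean_j W_E A_ij + b)
        def __init__(self, node_dim, edge_dim, n_filters):
            super().__init__()
            self.w_center = nn.Linear(node_dim, n_filters, bias=True)     # W_C and b
            self.w_neighbor = nn.Linear(node_dim, n_filters, bias=False)  # W_N
            self.w_edge = nn.Linear(edge_dim, n_filters, bias=False)      # W_E

        def forward(self, x, edge_attr, neighbors):
            # x: (n_nodes, node_dim) node features
            # neighbors: (n_nodes, k) indices of the k neighboring residues of each node
            # edge_attr: (n_nodes, k, edge_dim) distance/angle features of those edges
            neighbor_feats = x[neighbors]                          # (n_nodes, k, node_dim)
            z = (self.w_center(x)
                 + self.w_neighbor(neighbor_feats).mean(dim=1)
                 + self.w_edge(edge_attr).mean(dim=1))
            return torch.relu(z)  # sigma: any non-linearity; ReLU is used here only for illustration

The order-dependent variant of Eq. (3) would instead keep a separate W_N_j and W_E_j per neighbor position j, trading more parameters for the ability to distinguish neighbors.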
Multiple layers of these graph convolution operators can be used, and this will have the effect
of learning features that characterize the graph at increasing levels of abstraction, and will also
allow information to propagate through the graph, thereby integrating information across regions of
increasing size. Furthermore, these operators are rotation-invariant if the features have this property.
In convolutional networks, inputs are often downsampled based on the size and stride of the receptive
field. It is also common to use pooling to further reduce the size of the input. Our graph operators
on the other hand maintain the structure of the graph, which is necessary for the protein interface
prediction problem, where we classify pairs of nodes from different graphs, rather than entire
graphs. Using convolutional architectures that use only convolutional layers without downsampling is
common practice in the area of graph convolutional networks, especially if classification is performed
at the node or edge level. This practice has support from the success of networks without pooling
layers in the realm of object recognition [23]. The downside of not downsampling is higher memory
and computational costs.
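Tying the operator to the pairwise architecture of Figure 2, the following is a hedged sketch of how a (ligand residue, receptor residue) pair could be scored; the module names are invented for the example, while the shared convolution legs, the 256-filter / 512-unit sizes and the binary output follow the figure and table captions.

    import torch
    import torch.nn as nn

    class InterfacePairClassifier(nn.Module):
        # Shared graph convolutions over both proteins, merge by concatenation,
        # then a dense layer and a binary classification layer (cf. Figure 2).
        def __init__(self, conv, n_filters=256, hidden_dim=512):
            super().__init__()
            self.conv = conv                       # e.g. NodeEdgeAverageConv, shared by both legs
            self.dense = nn.Linear(2 * n_filters, hidden_dim)
            self.out = nn.Linear(hidden_dim, 1)    # interface / non-interface logit

        def forward(self, ligand_inputs, receptor_inputs, pairs):
            # pairs: (n_pairs, 2) indices of (ligand residue, receptor residue) to classify
            z_l = self.conv(*ligand_inputs)        # (n_ligand_residues, n_filters)
            z_r = self.conv(*receptor_inputs)      # (n_receptor_residues, n_filters)
            merged = torch.cat([z_l[pairs[:, 0]], z_r[pairs[:, 1]]], dim=-1)
            h = torch.relu(self.dense(merged))
            return self.out(h)                     # one logit per residue pair

Because only convolution, merging and dense layers are used, the graphs are never downsampled, which matches the node-pair (rather than whole-graph) classification described above.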
Related work. Several authors have recently proposed graph convolutional operators that generalize
Method Convolutional Layers
1 2 3 4
No Convolution 0.812 (0.007) 0.810 (0.006) 0.808 (0.006) 0.796 (0.006)
Diffusion (DCNN) (2 hops) [5] 0.790 (0.014) – – –
Diffusion (DCNN) (5 hops) [5]) 0.828 (0.018) – – –
Single Weight Matrix (MFN [9]) 0.865 (0.007) 0.871 (0.013) 0.873 (0.017) 0.869 (0.017)
Node Average (Equation (1)) 0.864 (0.007) 0.882 (0.007) 0.891 (0.005) 0.889 (0.005)
Node and Edge Average (Equation (2)) 0.876 (0.005) 0.898 (0.005) 0.895 (0.006) 0.889 (0.007)
DTNN [21] 0.867 (0.007) 0.880 (0.007) 0.882 (0.008) 0.873 (0.012)
Order Dependent (Equation (3)) 0.854 (0.004) 0.873 (0.005) 0.891 (0.004) 0.889 (0.008)
Table 2: Median area under the receiver operating characteristic curve (AUC) across all complexes in the
test set for various graph convolutional methods. Results shown are the average and standard deviation over
ten runs with different random seeds. Networks have the following number of filters for 1, 2, 3, and 4 layers
before merging, respectively: (256), (256, 512), (256, 256, 512), (256, 256, 512, 512). The exception is the
DTNN method, which by necessity produces an output that has the same dimensionality as its input. Unlike
the other methods, diffusion convolution performed best with an RBF with a standard deviation of 2Å. After
merging, all networks have a dense layer with 512 hidden units followed by a binary classification layer. Bold
faced values indicate best performance for each method.
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Pixels to Graphs by Associative Embedding (Newell & Deng, NIPS 2017)
Figure 1: Scene graphs are defined by the objects in an image (vertices) and their interactions (edges).
The ability to express information about the connections between objects make scene graphs a useful
representation for many computer vision tasks including captioning and visual question answering.
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Pixels to Graphs by Associative Embedding (Newell & Deng, NIPS 2017)
Figure 2: Full pipeline for object and relationship detection. A network is trained to produce two
heatmaps that activate at the predicted locations of objects and relationships. Feature vectors are
extracted from the pixel locations of top activations and fed through fully connected networks to
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Pixels to Graphs by Associative Embedding (Newell & Deng, NIPS 2017)
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Pixels to Graphs by Associative Embedding (Newell & Deng, NIPS 2017)
not indicate which vertex serves as the source and which serves as the destination, nor does it
disambiguate between pairs of vertices that happen to share the same midpoint.
To train the network to produce a coherent set of embeddings we build off of the loss penalty used in
[20]. During training, we have a ground truth set of annotations defining the unique objects in the
scene and the edges between these objects. This allows us to enforce two penalties: that an edge points
to a vertex by matching its output embedding as closely as possible, and that the embedding vectors
produced for each vertex are sufficiently different. We think of the first as “pulling together” all
references to a single vertex, and the second as “pushing apart” the references to different individual
vertices.
We consider an embedding h_i ∈ R^d produced for a vertex v_i ∈ V. All edges that connect to this
vertex produce a set of embeddings h'_{ik}, k = 1, ..., K_i, where K_i is the total number of references to
that vertex. Given an image with n objects the loss to “pull together” these embeddings is:
    L_{pull} = \frac{1}{\sum_{i=1}^{n} K_i} \sum_{i=1}^{n} \sum_{k=1}^{K_i} (h_i - h'_{ik})^2
To “push apart” embeddings across different vertices we first used the penalty described in [20],
but experienced difficulty with convergence. We tested alternatives and the most reliable loss was a
margin-based penalty similar to [9]:
    L_{push} = \sum_{i=1}^{n-1} \sum_{j=i+1}^{n} \max(0, m - \|h_i - h_j\|)
Intuitively, Lpush is at its highest the closer hi and hj are to each other. The penalty drops off sharply
as the distance between hi and hj grows, eventually hitting zero once the distance is greater than a
given margin m. On the flip side, for some edge connected to a vertex vi, the loss Lpull will quickly
grow the further its reference embedding h'_i is from h_i.
The two penalties are weighted equally leaving a final associative embedding loss of Lpull + Lpush.
In this work, we use m = 8 and d = 8. Convergence of the network improves greatly after increasing
the dimension d of tags up from 1 as used in [20].
Once the network is trained with this loss, full construction of the graph can be performed with a
trivial postprocessing step. The network produces a pool of vertex and edge detections. For every
edge, we look at the source and destination embeddings and match them to the closest embedding
amongst the detected vertices. Multiple edges may have the same source and target vertices, vs and
vt, and it is also possible for vs to equal vt.
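A minimal sketch of the two penalties for a single image is given below; it is illustrative only, with variable names and the simple per-vertex loop assumed rather than taken from the paper's implementation.

    import torch

    def associative_embedding_loss(vertex_emb, edge_refs, margin=8.0):
        # vertex_emb: (n, d) embeddings h_i of the ground-truth vertices
        # edge_refs: list of length n; edge_refs[i] is a (K_i, d) tensor with the
        #            embeddings h'_ik produced by edges that refer to vertex i
        # margin: m in the push loss (the text uses m = 8 and d = 8)
        n = vertex_emb.shape[0]

        # L_pull: draw every edge reference towards the embedding of its vertex
        total_refs = sum(refs.shape[0] for refs in edge_refs)
        pull = sum(((vertex_emb[i] - edge_refs[i]) ** 2).sum() for i in range(n)) / total_refs

        # L_push: margin penalty pushing embeddings of different vertices apart
        push = vertex_emb.new_zeros(())
        for i in range(n - 1):
            dist = (vertex_emb[i] - vertex_emb[i + 1:]).norm(dim=-1)
            push = push + torch.clamp(margin - dist, min=0.0).sum()

        return pull + push  # the two penalties are weighted equally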
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
Pixels to Graphs by Associative Embedding (Newell & Deng, NIPS 2017)
SGGen (no RPN) SGGen (w/ RPN) SGCls PredCls
R@50 R@100 R@50 R@100 R@50 R@100 R@50 R@100
Lu et al. [18] – – 0.3 0.5 11.8 14.1 27.9 35.0
Xu et al. [26] – – 3.4 4.2 21.7 24.4 44.8 53.0
Our model 6.7 7.8 9.7 11.3 26.5 30.0 68.0 75.2
Table 1: Results on Visual Genome
Figure 3: Predictions on Visual Genome. In the top row, the network must produce all object and
relationship detections directly from the image. The second row includes examples from an easier
version of the task where object detections are provided. Relationships outlined in green correspond
to predictions that correctly matched to a ground truth annotation.
the course of training. We set s_o = 3 and s_r = 6, which is sufficient to completely accommodate the
detection annotations for all but a small fraction of cases.
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
n Few-shot Learning
• Santoro, Adam, et al. Meta-learning with memory-augmented neural networks.
• Vinyals, Oriol, et al. Matching networks for one shot learning. Advances in Neural Information Processing Systems.
• Snell, Jake, Kevin Swersky, and Richard Zemel. Prototypical networks for few-shot learning. Advances in Neural Information Processing Systems.
• Ravi, Sachin, and Hugo Larochelle. Optimization as a model for few-shot learning.
• Finn, Chelsea, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. arXiv preprint.
• Duan, Yan, et al. One-shot imitation learning. Advances in Neural Information Processing Systems.
• Finn, Chelsea, et al. One-shot visual imitation learning via meta-learning. arXiv preprint.
Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
n Graph Convolution
• Gilmer, Justin, et al. Neural message passing for quantum chemistry. arXiv preprint.
• Duvenaud, David, et al. Convolutional networks on graphs for learning molecular fingerprints. Advances in Neural Information Processing Systems.
• Li, Yujia, et al. Gated graph sequence neural networks. arXiv preprint.
• Schütt, Kristof, et al. Quantum-chemical insights from deep tensor neural networks. Nature Communications.
• Schütt, Kristof, et al. SchNet: A continuous-filter convolutional neural network for modeling quantum interactions. Advances in Neural Information Processing Systems.
• Fout, Alex, et al. Protein Interface Prediction using Graph Convolutional Networks. Advances in Neural Information Processing Systems.
• Newell, Alejandro, and Jia Deng. Pixels to Graphs by Associative Embedding. Advances in Neural Information Processing Systems.
More Related Content

What's hot

Lecture 06 marco aurelio ranzato - deep learning
Lecture 06   marco aurelio ranzato - deep learningLecture 06   marco aurelio ranzato - deep learning
Lecture 06 marco aurelio ranzato - deep learningmustafa sarac
 
ICML2013読み会 Large-Scale Learning with Less RAM via Randomization
ICML2013読み会 Large-Scale Learning with Less RAM via RandomizationICML2013読み会 Large-Scale Learning with Less RAM via Randomization
ICML2013読み会 Large-Scale Learning with Less RAM via RandomizationHidekazu Oiwa
 
Section5 Rbf
Section5 RbfSection5 Rbf
Section5 Rbfkylin
 
Neural Processes Family
Neural Processes FamilyNeural Processes Family
Neural Processes FamilyKota Matsui
 
NIPS読み会2013: One-shot learning by inverting a compositional causal process
NIPS読み会2013: One-shot learning by inverting  a compositional causal processNIPS読み会2013: One-shot learning by inverting  a compositional causal process
NIPS読み会2013: One-shot learning by inverting a compositional causal processnozyh
 
All Pair Shortest Path Algorithm – Parallel Implementation and Analysis
All Pair Shortest Path Algorithm – Parallel Implementation and AnalysisAll Pair Shortest Path Algorithm – Parallel Implementation and Analysis
All Pair Shortest Path Algorithm – Parallel Implementation and AnalysisInderjeet Singh
 
Data-Driven Recommender Systems
Data-Driven Recommender SystemsData-Driven Recommender Systems
Data-Driven Recommender Systemsrecsysfr
 
(Kpi summer school 2015) theano tutorial part2
(Kpi summer school 2015) theano tutorial part2(Kpi summer school 2015) theano tutorial part2
(Kpi summer school 2015) theano tutorial part2Serhii Havrylov
 
Learning stochastic neural networks with Chainer
Learning stochastic neural networks with ChainerLearning stochastic neural networks with Chainer
Learning stochastic neural networks with ChainerSeiya Tokui
 
Graph Neural Network in practice
Graph Neural Network in practiceGraph Neural Network in practice
Graph Neural Network in practicetuxette
 
Review : Prototype Mixture Models for Few-shot Semantic Segmentation
Review : Prototype Mixture Models for Few-shot Semantic SegmentationReview : Prototype Mixture Models for Few-shot Semantic Segmentation
Review : Prototype Mixture Models for Few-shot Semantic SegmentationDongmin Choi
 
Conditional neural processes
Conditional neural processesConditional neural processes
Conditional neural processesKazuki Fujikawa
 
論文紹介 Fast imagetagging
論文紹介 Fast imagetagging論文紹介 Fast imagetagging
論文紹介 Fast imagetaggingTakashi Abe
 
Gradient Estimation Using Stochastic Computation Graphs
Gradient Estimation Using Stochastic Computation GraphsGradient Estimation Using Stochastic Computation Graphs
Gradient Estimation Using Stochastic Computation GraphsYoonho Lee
 
Neural tool box
Neural tool boxNeural tool box
Neural tool boxMohan Raj
 
"Deep Learning" Chap.6 Convolutional Neural Net
"Deep Learning" Chap.6 Convolutional Neural Net"Deep Learning" Chap.6 Convolutional Neural Net
"Deep Learning" Chap.6 Convolutional Neural NetKen'ichi Matsui
 
Safety Verification of Deep Neural Networks_.pdf
Safety Verification of Deep Neural Networks_.pdfSafety Verification of Deep Neural Networks_.pdf
Safety Verification of Deep Neural Networks_.pdfPolytechnique Montréal
 

What's hot (18)

Lecture 06 marco aurelio ranzato - deep learning
Lecture 06   marco aurelio ranzato - deep learningLecture 06   marco aurelio ranzato - deep learning
Lecture 06 marco aurelio ranzato - deep learning
 
ICML2013読み会 Large-Scale Learning with Less RAM via Randomization
ICML2013読み会 Large-Scale Learning with Less RAM via RandomizationICML2013読み会 Large-Scale Learning with Less RAM via Randomization
ICML2013読み会 Large-Scale Learning with Less RAM via Randomization
 
Section5 Rbf
Section5 RbfSection5 Rbf
Section5 Rbf
 
Neural Processes Family
Neural Processes FamilyNeural Processes Family
Neural Processes Family
 
NIPS読み会2013: One-shot learning by inverting a compositional causal process
NIPS読み会2013: One-shot learning by inverting  a compositional causal processNIPS読み会2013: One-shot learning by inverting  a compositional causal process
NIPS読み会2013: One-shot learning by inverting a compositional causal process
 
All Pair Shortest Path Algorithm – Parallel Implementation and Analysis
All Pair Shortest Path Algorithm – Parallel Implementation and AnalysisAll Pair Shortest Path Algorithm – Parallel Implementation and Analysis
All Pair Shortest Path Algorithm – Parallel Implementation and Analysis
 
Data-Driven Recommender Systems
Data-Driven Recommender SystemsData-Driven Recommender Systems
Data-Driven Recommender Systems
 
(Kpi summer school 2015) theano tutorial part2
(Kpi summer school 2015) theano tutorial part2(Kpi summer school 2015) theano tutorial part2
(Kpi summer school 2015) theano tutorial part2
 
Learning stochastic neural networks with Chainer
Learning stochastic neural networks with ChainerLearning stochastic neural networks with Chainer
Learning stochastic neural networks with Chainer
 
Graph Neural Network in practice
Graph Neural Network in practiceGraph Neural Network in practice
Graph Neural Network in practice
 
Review : Prototype Mixture Models for Few-shot Semantic Segmentation
Review : Prototype Mixture Models for Few-shot Semantic SegmentationReview : Prototype Mixture Models for Few-shot Semantic Segmentation
Review : Prototype Mixture Models for Few-shot Semantic Segmentation
 
Conditional neural processes
Conditional neural processesConditional neural processes
Conditional neural processes
 
論文紹介 Fast imagetagging
論文紹介 Fast imagetagging論文紹介 Fast imagetagging
論文紹介 Fast imagetagging
 
Gradient Estimation Using Stochastic Computation Graphs
Gradient Estimation Using Stochastic Computation GraphsGradient Estimation Using Stochastic Computation Graphs
Gradient Estimation Using Stochastic Computation Graphs
 
Neural tool box
Neural tool boxNeural tool box
Neural tool box
 
"Deep Learning" Chap.6 Convolutional Neural Net
"Deep Learning" Chap.6 Convolutional Neural Net"Deep Learning" Chap.6 Convolutional Neural Net
"Deep Learning" Chap.6 Convolutional Neural Net
 
Getting started with image processing using Matlab
Getting started with image processing using MatlabGetting started with image processing using Matlab
Getting started with image processing using Matlab
 
Safety Verification of Deep Neural Networks_.pdf
Safety Verification of Deep Neural Networks_.pdfSafety Verification of Deep Neural Networks_.pdf
Safety Verification of Deep Neural Networks_.pdf
 

Similar to NIPS2017 Few-shot Learning and Graph Convolution

DeepXplore: Automated Whitebox Testing of Deep Learning
DeepXplore: Automated Whitebox Testing of Deep LearningDeepXplore: Automated Whitebox Testing of Deep Learning
DeepXplore: Automated Whitebox Testing of Deep LearningMasahiro Sakai
 
FPGA Implementation of A New Chien Search Block for Reed-Solomon Codes RS (25...
FPGA Implementation of A New Chien Search Block for Reed-Solomon Codes RS (25...FPGA Implementation of A New Chien Search Block for Reed-Solomon Codes RS (25...
FPGA Implementation of A New Chien Search Block for Reed-Solomon Codes RS (25...IJERA Editor
 
Metaheuristic Tuning of Type-II Fuzzy Inference System for Data Mining
Metaheuristic Tuning of Type-II Fuzzy Inference System for Data MiningMetaheuristic Tuning of Type-II Fuzzy Inference System for Data Mining
Metaheuristic Tuning of Type-II Fuzzy Inference System for Data MiningVarun Ojha
 
MATHEMATICAL MODELING OF COMPLEX REDUNDANT SYSTEM UNDER HEAD-OF-LINE REPAIR
MATHEMATICAL MODELING OF COMPLEX REDUNDANT SYSTEM UNDER HEAD-OF-LINE REPAIRMATHEMATICAL MODELING OF COMPLEX REDUNDANT SYSTEM UNDER HEAD-OF-LINE REPAIR
MATHEMATICAL MODELING OF COMPLEX REDUNDANT SYSTEM UNDER HEAD-OF-LINE REPAIREditor IJMTER
 
Project 2: Baseband Data Communication
Project 2: Baseband Data CommunicationProject 2: Baseband Data Communication
Project 2: Baseband Data CommunicationDanish Bangash
 
Higher Order Fused Regularization for Supervised Learning with Grouped Parame...
Higher Order Fused Regularization for Supervised Learning with Grouped Parame...Higher Order Fused Regularization for Supervised Learning with Grouped Parame...
Higher Order Fused Regularization for Supervised Learning with Grouped Parame...Koh Takeuchi
 
nlp dl 1.pdf
nlp dl 1.pdfnlp dl 1.pdf
nlp dl 1.pdfnyomans1
 
Real Time System Identification of Speech Signal Using Tms320c6713
Real Time System Identification of Speech Signal Using Tms320c6713Real Time System Identification of Speech Signal Using Tms320c6713
Real Time System Identification of Speech Signal Using Tms320c6713IOSRJVSP
 
Simultaneous State and Actuator Fault Estimation With Fuzzy Descriptor PMID a...
Simultaneous State and Actuator Fault Estimation With Fuzzy Descriptor PMID a...Simultaneous State and Actuator Fault Estimation With Fuzzy Descriptor PMID a...
Simultaneous State and Actuator Fault Estimation With Fuzzy Descriptor PMID a...Waqas Tariq
 
【論文紹介】Relay: A New IR for Machine Learning Frameworks
【論文紹介】Relay: A New IR for Machine Learning Frameworks【論文紹介】Relay: A New IR for Machine Learning Frameworks
【論文紹介】Relay: A New IR for Machine Learning FrameworksTakeo Imai
 
Python for Chemistry
Python for ChemistryPython for Chemistry
Python for Chemistryguest5929fa7
 
Python for Chemistry
Python for ChemistryPython for Chemistry
Python for Chemistrybaoilleach
 
Mining of time series data base using fuzzy neural information systems
Mining of time series data base using fuzzy neural information systemsMining of time series data base using fuzzy neural information systems
Mining of time series data base using fuzzy neural information systemsDr.MAYA NAYAK
 
Regression and Classification with R
Regression and Classification with RRegression and Classification with R
Regression and Classification with RYanchang Zhao
 
Modal Analysis Basic Theory
Modal Analysis Basic TheoryModal Analysis Basic Theory
Modal Analysis Basic TheoryYuanCheng38
 
A Simple Communication System Design Lab #4 with MATLAB Simulink
A Simple Communication System Design Lab #4 with MATLAB SimulinkA Simple Communication System Design Lab #4 with MATLAB Simulink
A Simple Communication System Design Lab #4 with MATLAB SimulinkJaewook. Kang
 

Similar to NIPS2017 Few-shot Learning and Graph Convolution (20)

DeepXplore: Automated Whitebox Testing of Deep Learning
DeepXplore: Automated Whitebox Testing of Deep LearningDeepXplore: Automated Whitebox Testing of Deep Learning
DeepXplore: Automated Whitebox Testing of Deep Learning
 
FPGA Implementation of A New Chien Search Block for Reed-Solomon Codes RS (25...
FPGA Implementation of A New Chien Search Block for Reed-Solomon Codes RS (25...FPGA Implementation of A New Chien Search Block for Reed-Solomon Codes RS (25...
FPGA Implementation of A New Chien Search Block for Reed-Solomon Codes RS (25...
 
Metaheuristic Tuning of Type-II Fuzzy Inference System for Data Mining
Metaheuristic Tuning of Type-II Fuzzy Inference System for Data MiningMetaheuristic Tuning of Type-II Fuzzy Inference System for Data Mining
Metaheuristic Tuning of Type-II Fuzzy Inference System for Data Mining
 
MATHEMATICAL MODELING OF COMPLEX REDUNDANT SYSTEM UNDER HEAD-OF-LINE REPAIR
MATHEMATICAL MODELING OF COMPLEX REDUNDANT SYSTEM UNDER HEAD-OF-LINE REPAIRMATHEMATICAL MODELING OF COMPLEX REDUNDANT SYSTEM UNDER HEAD-OF-LINE REPAIR
MATHEMATICAL MODELING OF COMPLEX REDUNDANT SYSTEM UNDER HEAD-OF-LINE REPAIR
 
Project 2: Baseband Data Communication
Project 2: Baseband Data CommunicationProject 2: Baseband Data Communication
Project 2: Baseband Data Communication
 
Higher Order Fused Regularization for Supervised Learning with Grouped Parame...
Higher Order Fused Regularization for Supervised Learning with Grouped Parame...Higher Order Fused Regularization for Supervised Learning with Grouped Parame...
Higher Order Fused Regularization for Supervised Learning with Grouped Parame...
 
nlp dl 1.pdf
nlp dl 1.pdfnlp dl 1.pdf
nlp dl 1.pdf
 
Real Time System Identification of Speech Signal Using Tms320c6713
Real Time System Identification of Speech Signal Using Tms320c6713Real Time System Identification of Speech Signal Using Tms320c6713
Real Time System Identification of Speech Signal Using Tms320c6713
 
Perm winter school 2014.01.31
Perm winter school 2014.01.31Perm winter school 2014.01.31
Perm winter school 2014.01.31
 
Simultaneous State and Actuator Fault Estimation With Fuzzy Descriptor PMID a...
Simultaneous State and Actuator Fault Estimation With Fuzzy Descriptor PMID a...Simultaneous State and Actuator Fault Estimation With Fuzzy Descriptor PMID a...
Simultaneous State and Actuator Fault Estimation With Fuzzy Descriptor PMID a...
 
【論文紹介】Relay: A New IR for Machine Learning Frameworks
【論文紹介】Relay: A New IR for Machine Learning Frameworks【論文紹介】Relay: A New IR for Machine Learning Frameworks
【論文紹介】Relay: A New IR for Machine Learning Frameworks
 
3rd Semester Computer Science and Engineering (ACU - 2022) Question papers
3rd Semester Computer Science and Engineering  (ACU - 2022) Question papers3rd Semester Computer Science and Engineering  (ACU - 2022) Question papers
3rd Semester Computer Science and Engineering (ACU - 2022) Question papers
 
Python for Chemistry
Python for ChemistryPython for Chemistry
Python for Chemistry
 
Python for Chemistry
Python for ChemistryPython for Chemistry
Python for Chemistry
 
1st Semester M Tech: Computer Science and Engineering (Jun-2016) Question Pa...
1st  Semester M Tech: Computer Science and Engineering (Jun-2016) Question Pa...1st  Semester M Tech: Computer Science and Engineering (Jun-2016) Question Pa...
1st Semester M Tech: Computer Science and Engineering (Jun-2016) Question Pa...
 
Mining of time series data base using fuzzy neural information systems
Mining of time series data base using fuzzy neural information systemsMining of time series data base using fuzzy neural information systems
Mining of time series data base using fuzzy neural information systems
 
Regression and Classification with R
Regression and Classification with RRegression and Classification with R
Regression and Classification with R
 
Modal Analysis Basic Theory
Modal Analysis Basic TheoryModal Analysis Basic Theory
Modal Analysis Basic Theory
 
A Simple Communication System Design Lab #4 with MATLAB Simulink
A Simple Communication System Design Lab #4 with MATLAB SimulinkA Simple Communication System Design Lab #4 with MATLAB Simulink
A Simple Communication System Design Lab #4 with MATLAB Simulink
 
In gate-2016-paper
In gate-2016-paperIn gate-2016-paper
In gate-2016-paper
 

More from Kazuki Fujikawa

Stanford Covid Vaccine 2nd place solution
Stanford Covid Vaccine 2nd place solutionStanford Covid Vaccine 2nd place solution
Stanford Covid Vaccine 2nd place solutionKazuki Fujikawa
 
BMS Molecular Translation 3rd place solution
BMS Molecular Translation 3rd place solutionBMS Molecular Translation 3rd place solution
BMS Molecular Translation 3rd place solutionKazuki Fujikawa
 
Kaggle参加報告: Champs Predicting Molecular Properties
Kaggle参加報告: Champs Predicting Molecular PropertiesKaggle参加報告: Champs Predicting Molecular Properties
Kaggle参加報告: Champs Predicting Molecular PropertiesKazuki Fujikawa
 
Kaggle参加報告: Quora Insincere Questions Classification
Kaggle参加報告: Quora Insincere Questions ClassificationKaggle参加報告: Quora Insincere Questions Classification
Kaggle参加報告: Quora Insincere Questions ClassificationKazuki Fujikawa
 
Ordered neurons integrating tree structures into recurrent neural networks
Ordered neurons integrating tree structures into recurrent neural networksOrdered neurons integrating tree structures into recurrent neural networks
Ordered neurons integrating tree structures into recurrent neural networksKazuki Fujikawa
 
A closer look at few shot classification
A closer look at few shot classificationA closer look at few shot classification
A closer look at few shot classificationKazuki Fujikawa
 
Graph convolutional policy network for goal directed molecular graph generation
Graph convolutional policy network for goal directed molecular graph generationGraph convolutional policy network for goal directed molecular graph generation
Graph convolutional policy network for goal directed molecular graph generationKazuki Fujikawa
 
Matrix capsules with em routing
Matrix capsules with em routingMatrix capsules with em routing
Matrix capsules with em routingKazuki Fujikawa
 
Matching networks for one shot learning
Matching networks for one shot learningMatching networks for one shot learning
Matching networks for one shot learningKazuki Fujikawa
 
DeNAにおける機械学習・深層学習活用
DeNAにおける機械学習・深層学習活用DeNAにおける機械学習・深層学習活用
DeNAにおける機械学習・深層学習活用Kazuki Fujikawa
 

More from Kazuki Fujikawa (12)

Stanford Covid Vaccine 2nd place solution
Stanford Covid Vaccine 2nd place solutionStanford Covid Vaccine 2nd place solution
Stanford Covid Vaccine 2nd place solution
 
BMS Molecular Translation 3rd place solution
BMS Molecular Translation 3rd place solutionBMS Molecular Translation 3rd place solution
BMS Molecular Translation 3rd place solution
 
ACL2020 best papers
ACL2020 best papersACL2020 best papers
ACL2020 best papers
 
Kaggle参加報告: Champs Predicting Molecular Properties
Kaggle参加報告: Champs Predicting Molecular PropertiesKaggle参加報告: Champs Predicting Molecular Properties
Kaggle参加報告: Champs Predicting Molecular Properties
 
NLP@ICLR2019
NLP@ICLR2019NLP@ICLR2019
NLP@ICLR2019
 
Kaggle参加報告: Quora Insincere Questions Classification
Kaggle参加報告: Quora Insincere Questions ClassificationKaggle参加報告: Quora Insincere Questions Classification
Kaggle参加報告: Quora Insincere Questions Classification
 
Ordered neurons integrating tree structures into recurrent neural networks
Ordered neurons integrating tree structures into recurrent neural networksOrdered neurons integrating tree structures into recurrent neural networks
Ordered neurons integrating tree structures into recurrent neural networks
 
A closer look at few shot classification
A closer look at few shot classificationA closer look at few shot classification
A closer look at few shot classification
 
Graph convolutional policy network for goal directed molecular graph generation
Graph convolutional policy network for goal directed molecular graph generationGraph convolutional policy network for goal directed molecular graph generation
Graph convolutional policy network for goal directed molecular graph generation
 
Matrix capsules with em routing
Matrix capsules with em routingMatrix capsules with em routing
Matrix capsules with em routing
 
Matching networks for one shot learning
Matching networks for one shot learningMatching networks for one shot learning
Matching networks for one shot learning
 
DeNAにおける機械学習・深層学習活用
DeNAにおける機械学習・深層学習活用DeNAにおける機械学習・深層学習活用
DeNAにおける機械学習・深層学習活用
 

Recently uploaded

Smarteg dropshipping via API with DroFx.pptx
Smarteg dropshipping via API with DroFx.pptxSmarteg dropshipping via API with DroFx.pptx
Smarteg dropshipping via API with DroFx.pptxolyaivanovalion
 
Digital Advertising Lecture for Advanced Digital & Social Media Strategy at U...
Digital Advertising Lecture for Advanced Digital & Social Media Strategy at U...Digital Advertising Lecture for Advanced Digital & Social Media Strategy at U...
Digital Advertising Lecture for Advanced Digital & Social Media Strategy at U...Valters Lauzums
 
Call Girls Indiranagar Just Call 👗 7737669865 👗 Top Class Call Girl Service B...
Call Girls Indiranagar Just Call 👗 7737669865 👗 Top Class Call Girl Service B...Call Girls Indiranagar Just Call 👗 7737669865 👗 Top Class Call Girl Service B...
Call Girls Indiranagar Just Call 👗 7737669865 👗 Top Class Call Girl Service B...amitlee9823
 
ELKO dropshipping via API with DroFx.pptx
ELKO dropshipping via API with DroFx.pptxELKO dropshipping via API with DroFx.pptx
ELKO dropshipping via API with DroFx.pptxolyaivanovalion
 
April 2024 - Crypto Market Report's Analysis
April 2024 - Crypto Market Report's AnalysisApril 2024 - Crypto Market Report's Analysis
April 2024 - Crypto Market Report's Analysismanisha194592
 
Ravak dropshipping via API with DroFx.pptx
Ravak dropshipping via API with DroFx.pptxRavak dropshipping via API with DroFx.pptx
Ravak dropshipping via API with DroFx.pptxolyaivanovalion
 
FESE Capital Markets Fact Sheet 2024 Q1.pdf
FESE Capital Markets Fact Sheet 2024 Q1.pdfFESE Capital Markets Fact Sheet 2024 Q1.pdf
FESE Capital Markets Fact Sheet 2024 Q1.pdfMarinCaroMartnezBerg
 
Cheap Rate Call girls Sarita Vihar Delhi 9205541914 shot 1500 night
Cheap Rate Call girls Sarita Vihar Delhi 9205541914 shot 1500 nightCheap Rate Call girls Sarita Vihar Delhi 9205541914 shot 1500 night
Cheap Rate Call girls Sarita Vihar Delhi 9205541914 shot 1500 nightDelhi Call girls
 
Week-01-2.ppt BBB human Computer interaction
Week-01-2.ppt BBB human Computer interactionWeek-01-2.ppt BBB human Computer interaction
Week-01-2.ppt BBB human Computer interactionfulawalesam
 
Call Girls Bannerghatta Road Just Call 👗 7737669865 👗 Top Class Call Girl Ser...
Call Girls Bannerghatta Road Just Call 👗 7737669865 👗 Top Class Call Girl Ser...Call Girls Bannerghatta Road Just Call 👗 7737669865 👗 Top Class Call Girl Ser...
Call Girls Bannerghatta Road Just Call 👗 7737669865 👗 Top Class Call Girl Ser...amitlee9823
 
Generative AI on Enterprise Cloud with NiFi and Milvus
Generative AI on Enterprise Cloud with NiFi and MilvusGenerative AI on Enterprise Cloud with NiFi and Milvus
Generative AI on Enterprise Cloud with NiFi and MilvusTimothy Spann
 
ALSO dropshipping via API with DroFx.pptx
ALSO dropshipping via API with DroFx.pptxALSO dropshipping via API with DroFx.pptx
ALSO dropshipping via API with DroFx.pptxolyaivanovalion
 
BDSM⚡Call Girls in Mandawali Delhi >༒8448380779 Escort Service
BDSM⚡Call Girls in Mandawali Delhi >༒8448380779 Escort ServiceBDSM⚡Call Girls in Mandawali Delhi >༒8448380779 Escort Service
BDSM⚡Call Girls in Mandawali Delhi >༒8448380779 Escort ServiceDelhi Call girls
 
Edukaciniai dropshipping via API with DroFx
Edukaciniai dropshipping via API with DroFxEdukaciniai dropshipping via API with DroFx
Edukaciniai dropshipping via API with DroFxolyaivanovalion
 
Chintamani Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore ...
Chintamani Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore ...Chintamani Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore ...
Chintamani Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore ...amitlee9823
 
Junnasandra Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore...
Junnasandra Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore...Junnasandra Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore...
Junnasandra Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore...amitlee9823
 

Recently uploaded (20)

Smarteg dropshipping via API with DroFx.pptx
Smarteg dropshipping via API with DroFx.pptxSmarteg dropshipping via API with DroFx.pptx
Smarteg dropshipping via API with DroFx.pptx
 
Digital Advertising Lecture for Advanced Digital & Social Media Strategy at U...
Digital Advertising Lecture for Advanced Digital & Social Media Strategy at U...Digital Advertising Lecture for Advanced Digital & Social Media Strategy at U...
Digital Advertising Lecture for Advanced Digital & Social Media Strategy at U...
 
Call Girls Indiranagar Just Call 👗 7737669865 👗 Top Class Call Girl Service B...
Call Girls Indiranagar Just Call 👗 7737669865 👗 Top Class Call Girl Service B...Call Girls Indiranagar Just Call 👗 7737669865 👗 Top Class Call Girl Service B...
Call Girls Indiranagar Just Call 👗 7737669865 👗 Top Class Call Girl Service B...
 
ELKO dropshipping via API with DroFx.pptx
ELKO dropshipping via API with DroFx.pptxELKO dropshipping via API with DroFx.pptx
ELKO dropshipping via API with DroFx.pptx
 
April 2024 - Crypto Market Report's Analysis
April 2024 - Crypto Market Report's AnalysisApril 2024 - Crypto Market Report's Analysis
April 2024 - Crypto Market Report's Analysis
 
Sampling (random) method and Non random.ppt
Sampling (random) method and Non random.pptSampling (random) method and Non random.ppt
Sampling (random) method and Non random.ppt
 
Ravak dropshipping via API with DroFx.pptx
Ravak dropshipping via API with DroFx.pptxRavak dropshipping via API with DroFx.pptx
Ravak dropshipping via API with DroFx.pptx
 
FESE Capital Markets Fact Sheet 2024 Q1.pdf
FESE Capital Markets Fact Sheet 2024 Q1.pdfFESE Capital Markets Fact Sheet 2024 Q1.pdf
FESE Capital Markets Fact Sheet 2024 Q1.pdf
 
Cheap Rate Call girls Sarita Vihar Delhi 9205541914 shot 1500 night
Cheap Rate Call girls Sarita Vihar Delhi 9205541914 shot 1500 nightCheap Rate Call girls Sarita Vihar Delhi 9205541914 shot 1500 night
Cheap Rate Call girls Sarita Vihar Delhi 9205541914 shot 1500 night
 
Week-01-2.ppt BBB human Computer interaction
Week-01-2.ppt BBB human Computer interactionWeek-01-2.ppt BBB human Computer interaction
Week-01-2.ppt BBB human Computer interaction
 
Call Girls Bannerghatta Road Just Call 👗 7737669865 👗 Top Class Call Girl Ser...
Call Girls Bannerghatta Road Just Call 👗 7737669865 👗 Top Class Call Girl Ser...Call Girls Bannerghatta Road Just Call 👗 7737669865 👗 Top Class Call Girl Ser...
Call Girls Bannerghatta Road Just Call 👗 7737669865 👗 Top Class Call Girl Ser...
 
Predicting Loan Approval: A Data Science Project
Predicting Loan Approval: A Data Science ProjectPredicting Loan Approval: A Data Science Project
Predicting Loan Approval: A Data Science Project
 
Generative AI on Enterprise Cloud with NiFi and Milvus
Generative AI on Enterprise Cloud with NiFi and MilvusGenerative AI on Enterprise Cloud with NiFi and Milvus
Generative AI on Enterprise Cloud with NiFi and Milvus
 
ALSO dropshipping via API with DroFx.pptx
ALSO dropshipping via API with DroFx.pptxALSO dropshipping via API with DroFx.pptx
ALSO dropshipping via API with DroFx.pptx
 
BDSM⚡Call Girls in Mandawali Delhi >༒8448380779 Escort Service
BDSM⚡Call Girls in Mandawali Delhi >༒8448380779 Escort ServiceBDSM⚡Call Girls in Mandawali Delhi >༒8448380779 Escort Service
BDSM⚡Call Girls in Mandawali Delhi >༒8448380779 Escort Service
 
Call Girls In Shalimar Bagh ( Delhi) 9953330565 Escorts Service
Call Girls In Shalimar Bagh ( Delhi) 9953330565 Escorts ServiceCall Girls In Shalimar Bagh ( Delhi) 9953330565 Escorts Service
Call Girls In Shalimar Bagh ( Delhi) 9953330565 Escorts Service
 
Edukaciniai dropshipping via API with DroFx
Edukaciniai dropshipping via API with DroFxEdukaciniai dropshipping via API with DroFx
Edukaciniai dropshipping via API with DroFx
 
Anomaly detection and data imputation within time series
Anomaly detection and data imputation within time seriesAnomaly detection and data imputation within time series
Anomaly detection and data imputation within time series
 
Chintamani Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore ...
Chintamani Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore ...Chintamani Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore ...
Chintamani Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore ...
 
Junnasandra Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore...
Junnasandra Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore...Junnasandra Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore...
Junnasandra Call Girls: 🍓 7737669865 🍓 High Profile Model Escorts | Bangalore...
 

NIPS2017 Few-shot Learning and Graph Convolution

  • 1. AI System Dept. System Management Unit Kazuki Fujikawa 7 202 & 2 1 70 2
  • 2. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. n / / - / n / / -
  • 3. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. n - - / - / n / / -
  • 4. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. n FV b F caF iumw - C 1 8 C / 1 8 iuf n t pr m LI F So Nyf L nS lRL h de /8 1AA 1 /8 1 1 7 1A . 1 7 skg 01 CA 1 1 7 1 1 8 A 8 21 2 : C 2 7 1 7 1
  • 5. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. n F N I F SF - C 1 8 C / 1 8 RIL V n / - /8 1AA 1 /8 1 1 7 1A . 1 7 01 CA 1 1 7 1 1 8 A 8 21 2 : C 2 7 1 7 1
  • 6. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. )) ) ) ( ) ( ) ) .2 1 1 7 0 17 : 1 , 12 2 1 ,. 17 Learning to Learn ● What is Meta Learning / Learning to Learn? ○ Go beyond train/test from same distribution. ○ Task between train/test changes, so model has to “learn to learn” ● Datasets 132 Lake et al, 2013, 2015 I/O
  • 7. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. )) ) ) ( ) ( ) ) Learning to Learn 134 Losses .2 1 1 7 0 17 : 1 , 12 2 1 ,. 17
  • 8. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. )) ) ) ( ) ( ) ) Learning to Learn 134 Losses .2 1 1 7 0 17 : 1 , 12 2 1 ,. 17
  • 9. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. )) ) ) ( ) ( ) ) Learning to Learn 134 Losses .2 1 1 7 0 17 : 1 , 12 2 1 ,. 17
  • 10. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. )) ) ) ( ) ( ) ) Learning to Learn 134 Losses .2 1 1 7 0 17 : 1 , 12 2 1 ,. 17
  • 11. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. - n - 11 T’: Meta testing taskT: Meta training task dog frog horse ship truck airplane automobile bird cat deer dog frog horse ship truck airplane automobile bird cat deer • - : • - https://www.cs.toronto.edu/~kriz/cifar.html 1. Sample label set L from T 2. Sample a few images as support set S from L 3. Sample a few images as batch B from L 4. Optimize batch, Go to 1 S L Go Ger
  • 12. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. - n - . 12 T’: Meta testing taskT: Meta training task dog frog horse ship truck airplane automobile bird cat deer dog frog horse ship truck airplane automobile bird cat deer • T T : • M a N https://www.cs.toronto.edu/~kriz/cifar.html • T’ : 1. Sample label set L from T 2. Sample a few images as support set S from L 3. Sample a few images as batch B from L 4. Optimize batch, Go to 1 S L Go Ger
  • 13. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. - n - - 3 , 13 dog frog horse ship truck airplane automobile bird cat deer dog frog horse ship truck airplane automobile bird cat deer L’: Label set , 3 • : - 3 , automobile cat deer https://www.cs.toronto.edu/~kriz/cifar.html T’: Meta testing taskT: Meta training task 1. Sample label set L from T 2. Sample a few images as support set S from L 3. Sample a few images as batch B from L 4. Optimize batch, Go to 1 S L Go Ger
  • 14. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. - n - 14 dog frog horse ship truck airplane automobile bird cat deer dog frog horse ship truck airplane automobile bird cat deer L’: Label set S’: Support set : Query automobile cat deer 1- - - 1- 1- 1- 1- ˆx https://www.cs.toronto.edu/~kriz/cifar.html T’: Meta testing taskT: Meta training task 1. Sample label set L from T 2. Sample a few images as support set S from L 3. Sample a few images as batch B from L 4. Optimize batch, Go to 1 S L Go Ger
  • 15. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. - n - - 1 3 15 dog frog horse ship truck airplane automobile bird cat deer dog frog horse ship truck airplane automobile bird cat deer L’: Label set S’: Support set : Query automobile cat deer 31 - 31 31 ˆx • T LN :!" https://www.cs.toronto.edu/~kriz/cifar.html T’: Meta testing taskT: Meta training task 1. Sample label set L from T 2. Sample a few images as support set S from L 3. Sample a few images as batch B from L 4. Optimize batch, Go to 1 S L Go Ger
  • 16. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. )) ) ) ( ) ( ) ) Learning to Learn 134 Losses .2 1 1 7 0 17 : 1 , 12 2 1 ,. 17
  • 17. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. )) ) ) ( ) ( ) ) Losses Learning to Learn Model Based ● Santoro et al. ’16 ● Duan et al. ’17 ● Wang et al. ‘17 ● Munkhdalai & Yu ‘17 ● Mishra et al. ‘17 Optimization Based ● Schmidhuber ’87, ’92 ● Bengio et al. ’90, ‘92 ● Hochreiter et al. ’01 ● Li & Malik ‘16 ● Andrychowicz et al. ’16 ● Ravi & Larochelle ‘17 ● Finn et al ‘17 Metric Based ● Koch ’15 ● Vinyals et al. ‘16 ● Snell et al. ‘17 ● Shyam et al. ‘17 Trends: Learning to Learn / Meta Learning 135 .2 1 1 7 0 17 : 1 , 12 2 1 ,. 17
  • 18. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. , ( 2 - 16 - ,(-- 1 ),+ 0 n z il Ue gL lO Ue g ilPSo 0 T s e g Ma Rz cd m • z m V N ISI cd M 0 ) . 1 ).1 p t z LilO U cd m PRyu PSe g v Oriol Vinyals, NIPS 17 Model Based Meta Learning rn 2 : ) 7 ( ) 7 : A:
  • 19. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. 0 1 6 2 1 + 1 , + n V sI Or zVih V O e msL 1 1 o V V h E !", $" v L 2 dg aS !", $" S Oh V t m w FPO ih NV uI 2 dg aS !", $" S Oh V t m ie n MOlSw e VpFPOih uL yu 0 12 : 7 0 2 0 20 -. 2 0 20 . 7: 7 Oriol Vinyals, NIPS 17 Matching Networks, Vinyals et al, NIPS 2016 Metric Based Meta Learning
  • 20. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. 0 1 0 0 7 72 ,+70 + - n k M L /-/ F N LF / - /:!" i -/ - h F e • • Zc g f a • 2.1 Notation In few-shot classification we are given a small support set of N labeled examples S = {(x1, y1), . . . , (xN , yN )} where each xi 2 RD is the D-dimensional feature vector of an example and yi 2 {1, . . . , K} is the corresponding label. Sk denotes the set of examples labeled with class k. 2.2 Model Prototypical Networks compute an M-dimensional representation ck 2 RM , or prototype, of each class through an embedding function f : RD ! RM with learnable parameters . Each prototype is the mean vector of the embedded support points belonging to its class: ck = 1 |Sk| X (xi,yi)2Sk f (xi) (1) Given a distance function d : RM ⇥ RM ! [0, +1), Prototypical Networks produce a distribution over classes for a query point x based on a softmax over distances to the prototypes in the embedding space: p (y = k | x) = exp( d(f (x), ck)) P k0 exp( d(f (x), ck0 )) (2) Learning proceeds by minimizing the negative log-probability J( ) = log p (y = k | x) of the true class k via SGD. Training episodes are formed by randomly selecting a subset of classes from the training set, then choosing a subset of examples within each class to act as the support set and a 2 ing ers from few-shot learning in that instead of being given a support set of given a class meta-data vector vk for each class. These could be determined d be learned from e.g., raw text [8]. Modifying Prototypical Networks to deal s straightforward: we simply define ck = g#(vk) to be a separate embedding . An illustration of the zero-shot procedure for Prototypical Networks as ot procedure is shown in Figure 1. Since the meta-data vector and query ent input domains, we found it was helpful empirically to fix the prototype it length, however we do not constrain the query embedding f. we performed experiments on Omniglot [18] and the miniImageNet version with the splits proposed by Ravi and Larochelle [24]. We perform zero-shot 1 version of the Caltech UCSD bird dataset (CUB-200 2011) [34]. ot Classification et of 1623 handwritten characters collected from 50 alphabets. There are 20 Notation ew-shot classification we are given a small support set of N labeled examples S = , y1), . . . , (xN , yN )} where each xi 2 RD is the D-dimensional feature vector of an example yi 2 {1, . . . , K} is the corresponding label. Sk denotes the set of examples labeled with class k. Model otypical Networks compute an M-dimensional representation ck 2 RM , or prototype, of each through an embedding function f : RD ! RM with learnable parameters . Each prototype mean vector of the embedded support points belonging to its class: ck = 1 |Sk| X (xi,yi)2Sk f (xi) (1) n a distance function d : RM ⇥ RM ! [0, +1), Prototypical Networks produce a distribution classes for a query point x based on a softmax over distances to the prototypes in the embedding e: p (y = k | x) = exp( d(f (x), ck)) P k0 exp( d(f (x), ck0 )) (2) ning proceeds by minimizing the negative log-probability J( ) = log p (y = k | x) of the class k via SGD. Training episodes are formed by randomly selecting a subset of classes from aining set, then choosing a subset of examples within each class to act as the support set and a 2
• 21. Prototypical Networks — results
- Few-shot classification on Omniglot and miniImageNet; zero-shot classification on CUB-200.
- Table 1: Few-shot classification accuracies on Omniglot (* uses non-standard train/test splits).
    Model                           Dist.    Fine Tune   5-way 1-shot   5-way 5-shot   20-way 1-shot   20-way 5-shot
    Matching Networks [32]          Cosine   N           98.1%          98.9%          93.8%           98.5%
    Matching Networks [32]          Cosine   Y           97.9%          98.7%          93.5%           98.7%
    Neural Statistician [7]         -        N           98.1%          99.5%          93.2%           98.1%
    MAML [9]*                       -        N           98.7%          99.9%          95.8%           98.9%
    Prototypical Networks (ours)    Euclid.  N           98.8%          99.7%          96.0%           98.9%
- Table 2: Few-shot classification accuracies on miniImageNet (averaged over 600 test episodes, 95% confidence intervals; * results reported by [24]).
    Model                           Dist.    Fine Tune   5-way 1-shot      5-way 5-shot
    Baseline nearest neighbors*     Cosine   N           28.86 ± 0.54%     49.79 ± 0.79%
    Matching Networks [32]*         Cosine   N           43.40 ± 0.78%     51.09 ± 0.71%
    Matching Networks FCE [32]*     Cosine   N           43.56 ± 0.84%     55.31 ± 0.73%
    Meta-Learner LSTM [24]*         -        N           43.44 ± 0.77%     60.60 ± 0.71%
    MAML [9]                        -        N           48.70 ± 1.84%     63.15 ± 0.91%
    Prototypical Networks (ours)    Euclid.  N           49.42 ± 0.78%     68.20 ± 0.66%
- Figure 3: effect of distance metric and number of classes per training episode on 5-way accuracy for Matching Networks and Prototypical Networks on miniImageNet; error bars are 95% confidence intervals over 600 test episodes, and the two models are identical in the 1-shot case.
- Table 3: Zero-shot classification accuracies on CUB-200 (50-way, 0-shot).
    Model                           Image Features   Accuracy
    ALE [1]                         Fisher           26.9%
    SJE [2]                         AlexNet          40.3%
    Sample Clustering [19]          AlexNet          44.3%
    SJE [2]                         GoogLeNet        50.1%
    DS-SJE [25]                     GoogLeNet        50.4%
    DA-SJE [25]                     GoogLeNet        50.9%
    Synthesized Classifiers [6]     GoogLeNet        54.7%
    Prototypical Networks (ours)    GoogLeNet        54.8%
• 22. Optimization Based Meta Learning — Meta-Learner LSTM (Optimization as a Model for Few-Shot Learning, Ravi & Larochelle, ICLR 2017)
- Key observation: the gradient-descent update of a learner, θ_t = θ_{t−1} − α_t ∇_{θ_{t−1}} L_t, has the same form as the LSTM cell-state update c_t = f_t ⊙ c_{t−1} + i_t ⊙ c̃_t, with f_t = 1, c_{t−1} = θ_{t−1}, i_t = α_t and c̃_t = −∇_{θ_{t−1}} L_t.
- The meta-learner therefore learns the update rule θ_t = f_t ⊙ θ_{t−1} − i_t ⊙ ∇_{θ_{t−1}} L_t, where the gates f_t and i_t are functions of ∇_{θ_{t−1}} L_t, L_t, θ_{t−1} and the previous gate values; the learner's initialization θ_0 is also learned.
- From the paper: we propose training a meta-learner LSTM to learn an update rule for training a network. The cell state of the LSTM is set to the parameters of the learner, c_t = θ_t, and the candidate cell state to c̃_t = ∇_{θ_{t−1}} L_t, given how valuable information about the gradient is for optimization. Parametric forms are defined for i_t and f_t so the meta-learner can determine optimal values over the course of the updates:
    i_t = W_I · [∇_{θ_{t−1}} L_t, L_t, θ_{t−1}, i_{t−1}] + b_I
  so the learning rate is a function of the current parameter value, gradient, loss and previous learning rate, letting the meta-learner train the learner quickly while avoiding divergence; and
    f_t = W_F · [∇_{θ_{t−1}} L_t, L_t, θ_{t−1}, f_{t−1}] + b_F
  since the optimal choice of forget gate is not always the constant 1: shrinking the parameters and forgetting part of their previous value can help when the learner is stuck in a bad local optimum (high loss but near-zero gradient). The initial cell state c_0 is also learned as a meta-parameter, corresponding to the initial weights of the learner. The meta-test classes cover classes not present in any of the datasets in D_meta-train (with a meta-validation set used for hyper-parameters).
(Figure credit: Hugo Larochelle; slide credit: Oriol Vinyals, NIPS 17 — Optimization Based Meta Learning)
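A sketch of one learner update under this learned rule, assuming (for brevity) that the gates are produced by single linear layers passed through a sigmoid rather than by the full LSTM meta-learner; W_I, b_I, W_F, b_F are hypothetical, untrained meta-learner parameters and the quadratic toy loss is illustrative only:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def meta_lstm_step(theta, grad, loss, i_prev, f_prev, W_I, b_I, W_F, b_F):
    """One learner update theta_t = f_t * theta_{t-1} - i_t * grad, with the gates
    computed per parameter from [grad, loss, theta_{t-1}, previous gate]."""
    feats_i = np.stack([grad, np.full_like(theta, loss), theta, i_prev], axis=1)
    feats_f = np.stack([grad, np.full_like(theta, loss), theta, f_prev], axis=1)
    i_t = sigmoid(feats_i @ W_I + b_I)   # per-parameter learning rate
    f_t = sigmoid(feats_f @ W_F + b_F)   # per-parameter forget gate (≈1 recovers SGD)
    theta_t = f_t * theta - i_t * grad
    return theta_t, i_t, f_t

# toy learner: quadratic loss L(theta) = 0.5 * ||theta||^2, so grad = theta
rng = np.random.default_rng(0)
theta = rng.normal(size=5)
i_prev, f_prev = np.zeros(5), np.ones(5)
W_I, b_I = rng.normal(scale=0.1, size=4), -2.0   # untrained gate weights (assumed)
W_F, b_F = rng.normal(scale=0.1, size=4), 4.0    # bias keeps f_t close to 1 initially
for _ in range(3):
    loss = 0.5 * float(theta @ theta)
    theta, i_prev, f_prev = meta_lstm_step(theta, theta, loss, i_prev, f_prev,
                                           W_I, b_I, W_F, b_F)
    print(round(loss, 4))
```

In the actual method these gate parameters are trained by back-propagating the learner's loss on held-out episode data through the whole sequence of updates.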
• 23. Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (MAML, Finn et al., ICML 2017)
- Learns an initialization θ from which a small number of gradient steps on a new task yields good performance; applicable to any model trained with gradient descent (the same mechanism covers classification, regression and reinforcement learning).
- Algorithm 1 (Model-Agnostic Meta-Learning):
    Require: p(T): distribution over tasks; α, β: step size hyperparameters
    1: randomly initialize θ
    2: while not done do
    3:   Sample batch of tasks T_i ~ p(T)
    4:   for all T_i do
    5:     Evaluate ∇_θ L_{T_i}(f_θ) with respect to K examples
    6:     Compute adapted parameters with gradient descent: θ'_i = θ − α ∇_θ L_{T_i}(f_θ)
    7:   end for
    8:   Update θ ← θ − β ∇_θ Σ_{T_i ~ p(T)} L_{T_i}(f_{θ'_i})
    9: end while
  (The outer update requires gradients through gradients, i.e. Hessian-vector products, supported by standard autodiff libraries.)
- Figure 1: MAML optimizes for a representation θ that can quickly adapt to new tasks (gradients ∇L_1, ∇L_2, ∇L_3 pulling toward the task optima θ*_1, θ*_2, θ*_3).
- Summing up (Oriol Vinyals, NIPS 17): Model Based / Metric Based / Optimization Based meta-learning; examples of optimization-based meta-learning: Finn et al. '17, Ravi et al. '17.
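A minimal sketch of Algorithm 1 on a toy distribution of linear-regression tasks. To keep it dependency-free it uses the first-order approximation (the outer gradient is taken at θ'_i while ignoring the second-derivative term), so it is FOMAML rather than exact MAML; the task sampler and step sizes are illustrative assumptions:

```python
import numpy as np

def loss_and_grad(theta, x, y):
    """Task loss for a linear regressor y ≈ x @ theta, and its gradient."""
    err = x @ theta - y
    return 0.5 * np.mean(err ** 2), x.T @ err / len(y)

def maml_outer_step(theta, tasks, alpha=0.01, beta=0.01, k=10):
    """One meta-update: inner gradient step per task on K support examples,
    then a first-order outer step on the post-update query loss."""
    meta_grad = np.zeros_like(theta)
    for x_s, y_s, x_q, y_q in tasks:
        _, g = loss_and_grad(theta, x_s[:k], y_s[:k])
        theta_i = theta - alpha * g                 # adapted parameters theta'_i
        _, g_q = loss_and_grad(theta_i, x_q, y_q)   # grad of L_Ti(f_theta'_i)
        meta_grad += g_q
    return theta - beta * meta_grad / len(tasks)

rng = np.random.default_rng(0)
dim, theta = 3, np.zeros(3)

def sample_task():
    w = rng.normal(size=dim)                 # task-specific true weights
    x = rng.normal(size=(40, dim))
    return x[:20], x[:20] @ w, x[20:], x[20:] @ w

for _ in range(200):
    theta = maml_outer_step(theta, [sample_task() for _ in range(4)])

# meta-test: adapt with one gradient step and evaluate on the query set
x_s, y_s, x_q, y_q = sample_task()
_, g = loss_and_grad(theta, x_s, y_s)
print(loss_and_grad(theta - 0.01 * g, x_q, y_q)[0])   # post-adaptation query loss
```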
• 24. One-Shot Imitation Learning (Duan et al., NIPS 2017)
- Goal: a policy that, conditioned on a single demonstration of a new task and the current state, outputs the action to perform that task.
- Architecture (block-stacking domain, blocks A–J on a table), three modules:
  - Demonstration Network: embeds the demonstration (temporal dropout, neighborhood attention + temporal convolution)
  - Context Network: attention over the demonstration and over the current state, producing a context embedding
  - Manipulation Network: maps the context embedding to an action
• 25. One-Shot Imitation Learning — Demonstration Network
- Receives the demonstration trajectory and embeds it; because demonstrations can be very long, temporal dropout randomly discards most timesteps, downsampling the sequence to keep computation tractable.
- After downsampling, a sequence of operations composed of dilated temporal convolution (Yu & Koltun, 2016) and neighborhood attention is applied.
- Neighborhood attention (Figure 4): receives a list of embeddings h_1 … h_N, one per block, and performs N attentions, the i-th attention comparing the i-th block against the others, so the operation handles a variable number of blocks.
• 26. One-Shot Imitation Learning — Context Network
- The crux of the model: it processes both the current state and the embedding produced by the demonstration network, and outputs a fixed-size context embedding.
- Attention over the demonstration: a query computed from the current state attends over the (downsampled) demonstration embedding, extracting the part of the demonstration relevant to the current stage of the task.
- Attention over the current state: a second attention over the per-block embeddings of the current observation summarizes the scene; the result is passed to the manipulation network as the context embedding.
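A simplified sketch of the two ingredients above, assuming plain dot-product attention in place of the paper's neighborhood attention + temporal convolution stack; the embedding dimensions, keep probability handling and query construction are illustrative assumptions (the paper reports dropping timesteps with p = 0.95, i.e. keeping roughly 5%):

```python
import numpy as np

def temporal_dropout(demo, keep_prob=0.05, rng=None):
    """Randomly keep ~5% of demonstration timesteps (≈20x downsampling)."""
    rng = rng or np.random.default_rng()
    mask = rng.random(len(demo)) < keep_prob
    mask[0] = mask[-1] = True          # always keep the endpoints
    return demo[mask]

def soft_attention(query, keys, values):
    """Dot-product attention: read from the demonstration embedding with a
    query derived from the current state."""
    scores = keys @ query / np.sqrt(len(query))
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ values, w

rng = np.random.default_rng(0)
demo = rng.normal(size=(400, 16))      # demonstration embedding, 400 timesteps
state_query = rng.normal(size=16)      # embedding of the current observation
short_demo = temporal_dropout(demo, rng=rng)
context, attn = soft_attention(state_query, short_demo, short_demo)
print(short_demo.shape, context.shape)
```

At test time several downsampled copies of the demonstration can be sampled and the resulting action distributions averaged, which the paper reports as a consistent improvement.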
• 27. One-Shot Imitation Learning — Manipulation Network and training details
- Manipulation Network: a simple feed-forward network that maps the context embedding to the action.
- Temporal dropout uses p = 0.95, reducing demonstration length by a factor of about 20; at test time multiple downsampled trajectories are sampled, their downstream results computed and averaged as an ensemble, which consistently improves performance.
- Findings: for harder (multi-stage) tasks, conditioning on the entire demonstration outperforms conditioning on the final state, which communicates the task but not how to accomplish it, and, more surprisingly, also outperforms conditioning on per-stage snapshots, likely due to the regularization effect of temporal dropout, which effectively augments the set of demonstrations seen during training. Behavioral cloning reaches the same performance as DAGGER, so interactive supervision is not required; injecting noise into trajectory collection is important for behavioral cloning to work well (noise can come from human tele-operation or be added artificially).
- Figure 4: visualization of the two attentions of interest, (a) over blocks in the current state and (b) over the downsampled demonstration, during an entire execution.
• 28. One-Shot Imitation Learning — experimental setup (block stacking)
- Tasks are block-stacking tasks that differ in the desired final layout of the blocks; for each task a set of demonstration trajectories is collected for training with a hard-coded (scripted) policy, and a separate set of trajectories and initial configurations is held out for evaluation.
• 29. One-Shot Imitation Learning — performance evaluation
- Figure 3: comparison of conditioning strategies (Demo / BC / DAGGER / Snapshot / Final state) on (a) training tasks and (b) test tasks, with tasks grouped by the number of stages required to complete them; more stages means harder tasks, and even the scripted policy frequently fails on the hardest ones.
- The darkest bar is the hard-coded policy, which unsurprisingly performs best most of the time. For architectures with temporal dropout, an ensemble of 10 downsampled demonstrations is used and the action distributions averaged; all architectures then take the greedy action.
- Success rate per task is measured by executing the greedy policy (most confident action at every time step) in 100 different configurations, each conditioned on a demonstration unseen during training, and averaging over all tasks in the same group.
• 30. One-Shot Visual Imitation Learning / Planning with Visual Foresight (Finn*, Yu*, Zhang, Abbeel, Levine; Ebert, Finn, Lee, Levine)
- Question: how can we enable robots to learn vision-based manipulation skills that generalize to new objects and goals, reusing data from other tasks to adapt from only one visual demonstration?
- Meta-imitation learning using MAML [1,2]: meta-training tasks each provide training and test demonstration sets; at meta-test time a new held-out task (held-out objects) is learned from a single demo. The policy learns from raw pixel observations rather than task-specific, engineered representations, collects data over a diverse range of objects and environments, and reuses data from other objects and tasks when learning a new task. Demo domain: robot placing, where tasks correspond to different objects.
- Contrast with the standard robotics paradigm (RGB-D image → segment objects → estimate pose and physics of segments → optimize actions using the estimates): a brittle, hand-engineered pipeline.
- Planning with visual foresight [3,5]: collect data entirely autonomously (program initial motions, provide objects, record camera images and robot actions; no object supervision, camera calibration or human annotation) and predict future video for different candidate actions, given a user input / task specification.
• 31. Learning a One-Shot Imitator with MAML [Finn*, Yu*, Zhang, Abbeel, Levine, 2017]
- Meta-learning loss: the MAML objective, where adaptation to a task is a gradient step on that task's demonstration.
- Task loss = behavioral cloning loss [Pomerleau '89, Sammut '92]: supervised regression from observations to the demonstrated actions.
(Slide credit: Pieter Abbeel — embody.ai / UC Berkeley / Gradescope)
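A short sketch of the inner (task-level) step with a behavioral-cloning loss, assuming a linear policy obs @ theta as a stand-in for the paper's convolutional policy; the outer MAML objective would then evaluate the adapted parameters on a second demonstration of the same task, exactly as in the MAML sketch above:

```python
import numpy as np

def bc_loss_and_grad(theta, obs, actions):
    """Behavioral-cloning task loss: mean squared error between predicted
    actions (linear policy obs @ theta) and demonstrated actions, plus grad."""
    err = obs @ theta - actions
    return 0.5 * np.mean(err ** 2), obs.T @ err / len(actions)

def one_shot_imitation_inner_step(theta, demo_obs, demo_actions, alpha=0.01):
    """MAML inner step on one demonstration: theta'_i = theta - alpha * grad."""
    _, g = bc_loss_and_grad(theta, demo_obs, demo_actions)
    return theta - alpha * g

rng = np.random.default_rng(0)
obs, act = rng.normal(size=(50, 4)), rng.normal(size=50)
theta_adapted = one_shot_imitation_inner_step(np.zeros(4), obs, act)
```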
• 32. Robot Experiments: Learning to Place [Finn*, Yu*, Zhang, Abbeel, Levine, 2017]
- Meta-training targets / objects versus meta-testing targets / objects (held out).
- 1,300 demonstrations used for meta-training.
(Slide credit: Pieter Abbeel — embody.ai / UC Berkeley / Gradescope)
• 33. Robot Experiments: Learning to Place [Finn*, Yu*, Zhang, Abbeel, Levine, 2017]
- Imitation from 1 demo at meta-test time; success rate: 90%.
(Slide credit: Pieter Abbeel — embody.ai / UC Berkeley / Gradescope)
• 34. Agenda (section divider)
• 35. Graph-structured data
- Much real-world data is naturally a graph (e.g. molecules, with atoms as nodes and bonds as edges), which motivates neural networks that operate directly on graph inputs rather than on fixed-size tensors or sequences.
• 36. Why Graphs? (Trends: Graph Networks — Oriol Vinyals, NIPS 17)
- Natural trend in the I/O handled by neural networks:
  1. Fixed inputs / outputs
  2. Tensor inputs / outputs
  3. Sequential inputs / outputs
  4. Graph / tree inputs / outputs
• 37. Message Passing Neural Networks (Trends: Graph Networks — Oriol Vinyals, NIPS 17)
- INPUT: an undirected graph G with node features h_v (vectors) and edge features e_vw (vectors).
- OUTPUT: a graph-level target.
- A family of neural networks that commute with the order of the vertices (permutation invariance); generalizes a few existing NN-based methods and is related to the Weisfeiler-Lehman algorithm for graph-isomorphism testing.
• 38. Neural Message Passing for Quantum Chemistry (Gilmer, Schoenholz, Riley, Vinyals, Dahl, ICML 2017)
- Goal: predict quantum-mechanical properties of small organic molecules (targets E, ω_0, …, as in the QM9 benchmark) directly from the molecular graph.
- Motivation: DFT computes these targets in roughly 10^3 seconds per molecule; a trained message passing neural network predicts them in about 10^-2 seconds.
• 39. The MPNN framework and its instances — Convolutional Networks on Graphs (Duvenaud et al., 2015)
- General framework (Gilmer et al.): message passing runs for T time steps and is defined in terms of message functions M_t and vertex update functions U_t. During the message passing phase, the hidden state h_v^t at each node is updated according to
    m_v^{t+1} = Σ_{w ∈ N(v)} M_t(h_v^t, h_w^t, e_vw)   (1)
    h_v^{t+1} = U_t(h_v^t, m_v^{t+1})                  (2)
  where N(v) denotes the neighbors of v in G. The readout phase computes a feature vector for the whole graph with a readout function R:
    ŷ = R({h_v^T | v ∈ G})                             (3)
  M_t, U_t and R are all learned differentiable functions; R operates on the set of node states and must be invariant to their permutation for the MPNN to be invariant to graph isomorphism. Edge features can also be learned by introducing hidden states for all edges and updating them analogously to Eqs. (1) and (2).
- Duvenaud et al. (molecular fingerprints) as an MPNN:
    Message function: M_t(h_v^t, h_w^t, e_vw) = CONCAT(h_w^t, e_vw)
    Update function:  U_t(h_v^t, m_v^{t+1}) = σ(H_t^{deg(v)} m_v^{t+1})
  where H_t^{deg(v)} is a learned matrix selected per time step t and per vertex degree deg(v).
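A generic sketch of Eqs. (1)–(3), with one concrete (hypothetical) choice of M, U and R roughly in the spirit of the summed-neighbour variants on these slides; the weight shapes and the 5-cycle toy graph are assumptions for illustration:

```python
import numpy as np

def mpnn_forward(h, edges, edge_feats, M, U, R, T=3):
    """Generic message passing: for T steps every node v sums the messages
    M(h_v, h_w, e_vw) from its neighbours, then updates its state with U;
    a permutation-invariant readout R maps all node states to the output."""
    for _ in range(T):
        m = [np.zeros_like(h[0]) for _ in h]
        for (v, w), e in zip(edges, edge_feats):
            m[v] = m[v] + M(h[v], h[w], e)
            m[w] = m[w] + M(h[w], h[v], e)     # undirected graph: both directions
        h = [U(h[v], m[v]) for v in range(len(h))]
    return R(h)

rng = np.random.default_rng(0)
d, de = 8, 4
W_m = rng.normal(scale=0.1, size=(d + de, d))
W_u = rng.normal(scale=0.1, size=(2 * d, d))
M = lambda hv, hw, e: np.concatenate([hw, e]) @ W_m      # linear message
U = lambda hv, m: np.tanh(np.concatenate([hv, m]) @ W_u)  # nonlinear update
R = lambda hs: np.sum(hs, axis=0)                         # simple sum readout

h0 = [rng.normal(size=d) for _ in range(5)]               # node features
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]          # a 5-cycle
edge_feats = [rng.normal(size=de) for _ in edges]
print(mpnn_forward(h0, edges, edge_feats, M, U, R).shape)  # (8,)
```

The variants on the following slides differ only in which M, U and R they plug into this loop.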
• 40. Convolutional Networks on Graphs (Duvenaud et al., 2015) — readout
- Readout function: R({h_v^T | v ∈ G}) = f( Σ_{v,t} softmax(W_t h_v^t) ), i.e. the softmax of each node state at each time step is summed over nodes and steps and fed to a neural network f.
• 41. Gated Graph Neural Networks (GG-NN, Li et al., ICLR 2016) as an MPNN
- Message function: M_t(h_v^t, h_w^t, e_vw) = A_{e_vw} h_w^t, where A_{e_vw} is a learned matrix, one per edge label (e.g. bond type).
- Update function: U_t(h_v^t, m_v^{t+1}) = GRU(h_v^t, m_v^{t+1}), the Gated Recurrent Unit of Cho et al. (2014), with weight tying so the same update function is used at every time step t.
• 42. Gated Graph Neural Networks (GG-NN) — readout
- Readout function: R({h_v^T | v ∈ G}) = tanh( Σ_v σ(i(h_v^T, h_v^0)) ⊙ tanh(j(h_v^T, h_v^0)) ), where i and j are neural networks; σ(i(·)) acts as a soft attention that decides how much each node contributes to the graph-level output.
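A small sketch of this gated readout, assuming single linear layers stand in for the networks i and j; weights and dimensions are illustrative:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_readout(h_T, h_0, W_i, W_j):
    """GG-NN-style readout: a sigmoid gate i(h_v^T, h_v^0) decides how much
    each node contributes, j(h_v^T, h_v^0) provides the contribution, and the
    gated contributions are summed (permutation-invariant) and squashed."""
    x = np.concatenate([h_T, h_0], axis=1)        # (num_nodes, 2d)
    gate = sigmoid(x @ W_i)                        # i(.,.): here one linear layer
    val = np.tanh(x @ W_j)                         # j(.,.): likewise
    return np.tanh((gate * val).sum(axis=0))

rng = np.random.default_rng(0)
d, n = 8, 6
h_0 = rng.normal(size=(n, d))      # initial node states
h_T = rng.normal(size=(n, d))      # node states after T propagation steps
W_i = rng.normal(scale=0.1, size=(2 * d, d))
W_j = rng.normal(scale=0.1, size=(2 * d, d))
print(gated_readout(h_T, h_0, W_i, W_j).shape)    # (8,)
```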
• 43. Deep Tensor Neural Networks (DTNN, Schütt et al., 2017) as an MPNN
- Message function: M_t(h_v^t, h_w^t, e_vw) = tanh( W^{fc}((W^{cf} h_w^t + b_1) ⊙ (W^{df} e_vw + b_2)) ), where W^{fc}, W^{cf}, W^{df} are learned matrices and b_1, b_2 learned bias vectors; node and edge features interact multiplicatively before a final linear layer and tanh.
- Update function: U_t(h_v^t, m_v^{t+1}) = h_v^t + m_v^{t+1}, a simple residual accumulation of the incoming messages.
• 44. Deep Tensor Neural Networks (DTNN) — readout
- Readout function: R({h_v^T | v ∈ G}) = Σ_v NN(h_v^T), i.e. a neural network is applied to each final node state and the per-node outputs are summed (for molecules, a sum of per-atom contributions).
• 45. Neural Message Passing for Quantum Chemistry (Gilmer et al., 2017) — proposed model (edge network + GRU)
- Message function (edge network): M_t(h_v^t, h_w^t, e_vw) = A(e_vw) h_w^t, where A(e_vw) is a neural network mapping the edge feature vector e_vw to a d × d matrix, so continuous edge features (e.g. interatomic distances) can modulate the message.
- Update function: U_t(h_v^t, m_v^{t+1}) = GRU(h_v^t, m_v^{t+1}), as in GG-NN.
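A sketch of one edge-network message plus a GRU node update, assuming a single linear layer stands in for the matrix-producing network A(e_vw) and using a standard GRU cell written out explicitly; all weight shapes are illustrative:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def edge_network_message(h_w, e_vw, W_edge, d):
    """Message A(e_vw) h_w: a small network (here one linear layer) maps the
    continuous edge feature vector to a d x d matrix applied to h_w."""
    A = (e_vw @ W_edge).reshape(d, d)
    return A @ h_w

def gru_update(h, m, Wz, Uz, Wr, Ur, Wh, Uh):
    """U_t(h_v, m_v) = GRU(h_v, m_v): the aggregated message is the input,
    the node state is the hidden state."""
    z = sigmoid(m @ Wz + h @ Uz)                  # update gate
    r = sigmoid(m @ Wr + h @ Ur)                  # reset gate
    h_tilde = np.tanh(m @ Wh + (r * h) @ Uh)      # candidate state
    return (1 - z) * h + z * h_tilde

rng = np.random.default_rng(0)
d, de = 8, 4
W_edge = rng.normal(scale=0.1, size=(de, d * d))
gru_w = [rng.normal(scale=0.1, size=(d, d)) for _ in range(6)]
h_v, h_w = rng.normal(size=d), rng.normal(size=d)
e_vw = rng.normal(size=de)
m_v = edge_network_message(h_w, e_vw, W_edge, d)   # in practice summed over neighbours
print(gru_update(h_v, m_v, *gru_w).shape)          # (8,)
```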
• 46. Neural Message Passing for Quantum Chemistry (Gilmer et al.) — readout with set2set
- Readout function: R({h_v^T | v ∈ G}) = set2set({h_v^T | v ∈ G}), the Read-Process-and-Write model of Vinyals et al. (2016), which can also be seen as a special case of a Memory Network or Neural Turing Machine.
- It uses content-based attention, so the retrieved vector does not change if the memories are shuffled — crucial for treating the node states as a set:
    q_t = LSTM(q*_{t-1}),  e_{i,t} = f(m_i, q_t),  a_{i,t} = exp(e_{i,t}) / Σ_j exp(e_{j,t}),  r_t = Σ_i a_{i,t} m_i,  q*_t = [q_t, r_t]
  where the m_i are the memory vectors (node states), f computes a single scalar from m_i and q_t (e.g. a dot product), and the LSTM takes no inputs and only evolves the query state q*_t.
• 47. SchNet: a continuous-filter convolutional neural network for modeling quantum interactions (Schütt et al., NIPS 2017)
- Figure 2: illustration of SchNet with an architectural overview (left), the interaction block (middle) and the continuous-filter convolution with filter-generating network (right). The shifted softplus is defined as ssp(x) = ln(0.5 e^x + 0.5).
- Detailed slides: https://www.slideshare.net/KazukiFujikawa/schnet-a-continuousfilter-convolutional-neural-network-for-modeling-quantum-interactions
• 48. SchNet — continuous-filter convolution (cfconv) and filter-generating network
- Atoms are embedded from their nuclear charges Z_i; interactions between atoms i and j depend on the offset r_ij = r_i − r_j only through the distance d_ij = ||r_i − r_j||, which keeps the energy prediction rotationally invariant.
- From the paper: the cfconv layer and its filter-generating network are depicted in the right panel of Fig. 2. To satisfy the requirements for modeling molecular energies, the filters are restricted to be rotationally invariant by using interatomic distances d_ij as input for the filter network. Without further processing the filters would be strongly correlated, since a neural network after initialization is close to linear; this leads to a plateau early in training, which is avoided by expanding the distances in a grid of Gaussians (next slide).
• 49. SchNet — radial basis expansion of distances
- Before entering the filter network, the distance is expanded on a grid of Gaussians:
    e_k(r_i − r_j) = exp(−γ ||d_ij − μ_k||^2)
  with centers μ_1 = 0.1 Å, μ_2 = 0.2 Å, …, μ_300 = 30 Å (every 0.1 Å) and γ = 10 Å, chosen so that all distances occurring in the data sets are covered by the filters.
- This additional non-linearity decorrelates the initial filters and speeds up training; choosing fewer centers corresponds to reducing the resolution of the filter, while restricting the range of the centers corresponds to the filter size of a usual convolutional layer.
• 50. SchNet — filter-generating network
- The expanded distances are fed into two dense layers with softplus activations to compute the filter weight W(r_i − r_j) (Fig. 2, right).
• 51. SchNet — continuous-filter convolution in the interaction block
- In the cfconv layer, the atom-wise features x_j of the neighbouring atoms are element-wise multiplied by the generated filter W(r_i − r_j) and summed over j to update atom i.
- The sequential updates from three interaction blocks allow SchNet to construct highly complex many-body representations in the spirit of DTNNs, while keeping rotational invariance due to the radial filters.
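A compact sketch of the RBF expansion and the cfconv aggregation described above, assuming a single dense layer with shifted softplus stands in for the paper's two-layer filter-generating network; the atom counts, feature width and weights are illustrative:

```python
import numpy as np

def rbf_expand(d_ij, mu=np.linspace(0.1, 30.0, 300), gamma=10.0):
    """Expand a distance on 300 Gaussian centers spaced 0.1 A apart,
    e_k(r_i - r_j) = exp(-gamma * (d_ij - mu_k)^2)."""
    return np.exp(-gamma * (d_ij - mu) ** 2)

def cfconv(x, positions, filter_net):
    """Continuous-filter convolution: atom i aggregates the features of the
    other atoms j, elementwise-weighted by a filter W(r_i - r_j) produced by
    the filter-generating network from the expanded distance."""
    n = len(x)
    out = np.zeros_like(x)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            d_ij = np.linalg.norm(positions[i] - positions[j])
            out[i] += filter_net(rbf_expand(d_ij)) * x[j]
    return out

rng = np.random.default_rng(0)
n_atoms, d = 4, 16
W_f = rng.normal(scale=0.05, size=(300, d))            # hypothetical filter weights
ssp = lambda v: np.log(0.5 * np.exp(v) + 0.5)           # shifted softplus
filter_net = lambda rbf: ssp(rbf @ W_f)                  # one layer instead of two

x = rng.normal(size=(n_atoms, d))          # atom-wise features
positions = rng.normal(size=(n_atoms, 3))  # atom positions (Angstrom)
print(cfconv(x, positions, filter_net).shape)            # (4, 16)
```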
• 52. SchNet — learned filters
- Figure 3: 10×10 Å cuts through all 64 radial, three-dimensional filters in each interaction block of SchNet trained on molecular dynamics of ethanol (negative values blue, positive values red).
- Each filter emphasizes certain ranges of interatomic distances, which lets each interaction block update the representations according to the radial environment of each atom.
• 53. SchNet — architecture overview
- Figure 2 (again): architectural overview (left), interaction block (middle), continuous-filter convolution with filter-generating network (right); ssp(x) = ln(0.5 e^x + 0.5).
- Atom embeddings pass through a stack of interaction blocks, followed by atom-wise layers and a sum over atoms to produce the molecular energy.
• 54. SchNet — training with energies and forces
- Energy-conserving force model: the interatomic forces are related to the molecular energy, so an energy-conserving force model is obtained by differentiating the energy model with respect to the atom positions,
    F̂_i(Z_1, …, Z_n, r_1, …, r_n) = −∂Ê/∂r_i (Z_1, …, Z_n, r_1, …, r_n)   (4)
  Since SchNet yields rotationally invariant energy predictions, the force predictions are rotationally equivariant by construction. The model must be at least twice differentiable to allow gradient descent on the force loss, which is why the smooth shifted softplus ssp(x) = ln(0.5 e^x + 0.5) is used throughout to obtain a smooth potential energy surface.
- The total energy E and the forces F_i are both included in the training loss:
    ℓ(Ê, (E, F_1, …, F_n)) = ||E − Ê||^2 + (ρ/n) Σ_{i=0}^{n} || F_i − (−∂Ê/∂R_i) ||^2   (5)
  ρ = 0 gives pure energy-based training and ρ = 100 is used for combined energy and force training; ρ is chosen empirically to account for the different scales of energies and forces.
- Trade-offs: training with forces is expected to improve generalization, but a full forward and backward pass on the energy model is needed to obtain the forces, so the force model is twice as deep and takes roughly twice the computation. GDML captures the energy/force relationship but is explicitly optimized to predict the force field, with the energy as a by-product. Models such as circular fingerprints, molecular graph convolutions or message-passing neural networks for property prediction only handle equilibrium molecules (vanishing forces) and cannot be trained with forces in this way, because their predicted potential energy surfaces contain discontinuities from discrete binning or one-hot bond-type encodings.
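A direct sketch of the combined loss in Eq. (5), assuming the predicted energy Ê and the predicted forces F̂_i = −∂Ê/∂R_i have already been computed (e.g. by an autodiff framework); the numeric values are toy inputs:

```python
import numpy as np

def energy_force_loss(E_true, E_pred, F_true, F_pred, rho=100.0):
    """Eq. (5): squared energy error plus rho times the mean squared force
    error over the n atoms. F_pred should be the negative gradient of the
    predicted energy w.r.t. the atom positions, so the force model is
    energy-conserving by construction."""
    n_atoms = len(F_true)
    energy_term = (E_true - E_pred) ** 2
    force_term = np.sum((F_true - F_pred) ** 2) / n_atoms
    return energy_term + rho * force_term

# rho = 100 is the value the excerpt reports for combined training (rho = 0 for energy only)
print(energy_force_loss(-10.0, -9.8,
                        np.zeros((3, 3)), 0.01 * np.ones((3, 3))))
```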
• 55. SchNet — QM9 results
- SchNet achieves the lowest mean absolute errors in this comparison at every training set size.
- Table 1: mean absolute errors for energy predictions in kcal/mol on the QM9 data set with given training set size N (best model in bold in the paper).
    N         SchNet   DTNN [18]   enn-s2s [19]   enn-s2s-ens5 [19]
    50,000    0.59     0.94        –              –
    100,000   0.34     0.84        –              –
    110,462   0.31     –           0.45           0.33
• 56. Protein Interface Prediction using Graph Convolutional Networks (Fout et al., NIPS 2017)
- Task: predict which pairs of residues, one from a ligand protein and one from a receptor protein, belong to the interface between the two proteins.
- Figure 2: pairwise classification architecture — each residue neighborhood in both proteins is processed by one or more graph convolution layers (with weight sharing), the resulting residue representations are merged by concatenation into residue-pair representations, and fully connected layers perform the classification.
- Figure 1: each residue is a node in a graph whose neighborhood is the set of neighboring residues in the protein structure. Node features are computed from the amino-acid sequence and structure (residue conservation / composition, accessible surface area, residue depth, protrusion index); edge features describe the relative distance and angle between residues. The convolution has as its receptive field a set of neighboring residues and produces an activation associated with the center residue.
  • 57. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. Protein Interface Prediction using Graph Convolutional Networks [Fout+, NIPS 2017]
n Graph convolution operators (from the paper). The objective is a convolution that can be applied to graphs without a regular structure and without imposing a particular order on the neighbors of a node, i.e. a mapping at each node of the form $z_i = \sigma(W(x_i, \{x_{n_1}, \dots, x_{n_k}\}))$, where $\{n_1, \dots, n_k\}$ are the neighbors of node $i$ that define the receptive field of the convolution, $\sigma$ is a non-linear activation function, and $W$ are its learned parameters; treating the neighbors as a set keeps the mapping order-independent. The paper proposes the following realizations, which produce the output of a set of filters at a node of interest called the "center node" (a minimal code sketch of Eq. (2) follows below):
• Node average:
$$z_i = \sigma\Big(W^C x_i + \frac{1}{|N_i|}\sum_{j \in N_i} W^N x_j + b\Big) \quad (1)$$
where $N_i$ is the set of neighbors of node $i$, $W^C$ and $W^N$ are the weight matrices for the center node and for neighboring nodes, and $b$ is a vector of biases, one per filter. The dimensionality of the weight matrices is determined by the dimensionality of the inputs and the number of filters. On a graph with $n$ nodes, neighborhood size $k$, $F_{in}$ input features and $F_{out}$ output features, this operator costs $O(k\,F_{in}\,F_{out}\,n)$; constructing the neighborhoods is a preprocessing step taking $O(n^2 \log n)$.
• Node and edge average (differentiates between neighbors by incorporating features on the edges between each neighbor and the center node):
$$z_i = \sigma\Big(W^C x_i + \frac{1}{|N_i|}\sum_{j \in N_i} W^N x_j + \frac{1}{|N_i|}\sum_{j \in N_i} W^E A_{ij} + b\Big) \quad (2)$$
where $W^E$ is the weight matrix associated with the edge features $A_{ij}$.
• Order dependent (for comparison with the order-independent operators; the order is determined by distance from the center node, and each neighbor has its own weight matrices):
$$z_i = \sigma\Big(W^C x_i + \frac{1}{|N_i|}\sum_{j \in N_i} W^N_j x_j + \frac{1}{|N_i|}\sum_{j \in N_i} W^E_j A_{ij} + b\Big) \quad (3)$$
where $W^N_j$ / $W^E_j$ are the weight matrices associated with the $j$-th neighbor node and its edge. This operator is inspired by the PATCHY-SAN method of Niepert et al. [16]; it is more flexible than the order-independent operators, allowing distinctions between neighbors to be learned, at the cost of significantly more parameters.
n Multiple layers of these graph convolution operators can be stacked, learning features that characterize the graph at increasing levels of abstraction and letting information propagate through the graph, thereby integrating information across regions of increasing size. The operators are rotation-invariant if the features have this property.
n Unlike image CNNs, no downsampling or pooling is used: the graph structure must be maintained because the protein interface prediction task classifies pairs of nodes from different graphs rather than entire graphs. Convolution-only architectures without downsampling are common practice for graph convolutional networks, especially when classification is performed at the node or edge level, a practice also supported by the success of pooling-free networks in object recognition [23]; the downside is higher memory and computational cost.
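Not part of the original slides: a minimal NumPy sketch of the node-and-edge-average operator of Eq. (2), assuming dense edge-feature tensors and using ReLU as the nonlinearity (the paper only specifies a generic $\sigma$). All names (`node_edge_average_conv`, `neighbors`, ...) are hypothetical.

```python
import numpy as np

def node_edge_average_conv(X, A, neighbors, Wc, Wn, We, b):
    """One graph convolution layer in the spirit of Eq. (2):
    z_i = relu(Wc x_i + mean_j(Wn x_j) + mean_j(We A_ij) + b).

    X:         (n, F_in) node features
    A:         (n, n, F_edge) edge features (only neighbor entries are used)
    neighbors: list of length n; neighbors[i] lists the neighbor indices of node i
    Wc, Wn:    (F_in, F_out) weight matrices for the center node / neighbors
    We:        (F_edge, F_out) weight matrix for edge features
    b:         (F_out,) bias, one entry per filter
    """
    n = X.shape[0]
    Z = np.zeros((n, Wc.shape[1]))
    for i in range(n):
        nbrs = neighbors[i]
        center = X[i] @ Wc
        nbr_term = np.mean(X[nbrs] @ Wn, axis=0) if nbrs else 0.0
        edge_term = np.mean(A[i, nbrs] @ We, axis=0) if nbrs else 0.0
        Z[i] = np.maximum(0.0, center + nbr_term + edge_term + b)  # ReLU
    return Z

# Toy usage: 5 nodes, 2 neighbors each, 4 node features, 2 edge features, 3 filters.
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 4)); A = rng.normal(size=(5, 5, 2))
nbrs = [[(i + 1) % 5, (i + 2) % 5] for i in range(5)]
Z = node_edge_average_conv(X, A, nbrs,
                           rng.normal(size=(4, 3)), rng.normal(size=(4, 3)),
                           rng.normal(size=(2, 3)), np.zeros(3))
print(Z.shape)  # (5, 3)
```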
Table 2 (from the paper): median area under the ROC curve (AUC) across all complexes in the test set for the graph convolutional methods, reported as mean (standard deviation) over ten runs with different random seeds.

Method                             1 layer         2 layers        3 layers        4 layers
No Convolution                     0.812 (0.007)   0.810 (0.006)   0.808 (0.006)   0.796 (0.006)
Diffusion (DCNN, 2 hops) [5]       0.790 (0.014)   –               –               –
Diffusion (DCNN, 5 hops) [5]       0.828 (0.018)   –               –               –
Single Weight Matrix (MFN [9])     0.865 (0.007)   0.871 (0.013)   0.873 (0.017)   0.869 (0.017)
Node Average (Eq. 1)               0.864 (0.007)   0.882 (0.007)   0.891 (0.005)   0.889 (0.005)
Node and Edge Average (Eq. 2)      0.876 (0.005)   0.898 (0.005)   0.895 (0.006)   0.889 (0.007)
DTNN [21]                          0.867 (0.007)   0.880 (0.007)   0.882 (0.008)   0.873 (0.012)
Order Dependent (Eq. 3)            0.854 (0.004)   0.873 (0.005)   0.891 (0.004)   0.889 (0.008)

Networks use the following numbers of filters for 1, 2, 3 and 4 layers before merging: (256), (256, 512), (256, 256, 512), (256, 256, 512, 512); the exception is DTNN, which by necessity produces an output with the same dimensionality as its input. Unlike the other methods, diffusion convolution performed best with an RBF with a standard deviation of 2 Å. After merging, all networks have a dense layer with 512 hidden units followed by a binary classification layer.
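Not part of the original slides: a small sketch of how the Table 2 metric (median per-complex AUC over the test set) could be computed, assuming each test complex provides binary residue-pair labels and predicted scores. The helper name `median_test_auc` and the data layout are hypothetical.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def median_test_auc(complexes):
    """complexes: list of (y_true, y_score) pairs, one per test complex,
    where y_true holds binary interface labels for residue pairs and
    y_score holds the model's predicted scores."""
    aucs = [roc_auc_score(y_true, y_score) for y_true, y_score in complexes]
    return np.median(aucs)

# Toy usage with two fake complexes of 100 residue pairs each.
rng = np.random.default_rng(0)
fake = [(rng.integers(0, 2, 100), rng.random(100)) for _ in range(2)]
print(median_test_auc(fake))
```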
  • 58. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. Pixels to Graphs by Associative Embedding [Newell & Deng, NIPS 2017]
n Scene graphs (Figure 1 of the paper): a scene graph is defined by the objects in an image (vertices) and their interactions (edges). The ability to express information about the connections between objects makes scene graphs a useful representation for many computer vision tasks, including captioning and visual question answering.
  • 59. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. Pixels to Graphs by Associative Embedding [Newell & Deng, NIPS 2017]
n Pipeline (Figure 2 of the paper): the full pipeline for object and relationship detection. A network is trained to produce two heatmaps that activate at the predicted locations of objects and relationships; feature vectors are extracted from the pixel locations of the top activations and fed through fully connected networks to predict the properties of each detected object and relationship.
  • 60. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. Pixels to Graphs by Associative Embedding [Newell & Deng, NIPS 2017]
n Detection step (same pipeline figure as the previous slide): the two heatmaps mark candidate object and relationship locations; feature vectors are read out at the top activations and decoded by fully connected networks into the final detections.
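Not part of the original slides: a minimal NumPy sketch of the "read out features at the top heatmap activations" step from the pipeline, assuming a single heatmap and a per-pixel feature map from the backbone. Function and variable names (`extract_top_features`, ...) are hypothetical.

```python
import numpy as np

def extract_top_features(heatmap, feature_map, k):
    """Pick the top-k activations of a heatmap and read out the feature
    vectors at those pixel locations.

    heatmap:     (H, W) predicted activation map (objects or relationships)
    feature_map: (H, W, F) per-pixel feature vectors from the backbone
    returns:     (k, F) features and (k, 2) integer (row, col) locations
    """
    H, W = heatmap.shape
    flat_idx = np.argsort(heatmap, axis=None)[::-1][:k]   # indices of top-k pixels
    rows, cols = np.unravel_index(flat_idx, (H, W))
    feats = feature_map[rows, cols]                        # (k, F)
    return feats, np.stack([rows, cols], axis=1)

# Toy usage on random maps: 64x64 heatmap, 32-dim features, top 5 activations.
rng = np.random.default_rng(0)
feats, locs = extract_top_features(rng.random((64, 64)),
                                   rng.normal(size=(64, 64, 32)), k=5)
print(feats.shape, locs.shape)  # (5, 32) (5, 2)
```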
  • 61. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. Pixels to Graphs by Associative Embedding [Newell & Deng, NIPS 2017]
n Associative embedding loss (from the paper). An edge detection alone does not indicate which vertex serves as the source and which as the destination, nor does it disambiguate between pairs of vertices that happen to share the same midpoint. Training builds on the loss penalty of [20]: the ground-truth annotations define the unique objects in the scene and the edges between them, which allows two penalties to be enforced: an edge should point to a vertex by matching its output embedding as closely as possible ("pulling together" all references to a single vertex), and the embedding vectors produced for different vertices should be sufficiently different ("pushing apart" references to different vertices). (A small code sketch of these penalties follows after this slide.)
n Consider an embedding $h_i \in \mathbb{R}^d$ produced for a vertex $v_i \in V$. All edges that connect to this vertex produce a set of embeddings $h'_{ik}$, $k = 1, \dots, K_i$, where $K_i$ is the total number of references to that vertex. Given an image with $n$ objects, the loss to "pull together" these embeddings is
$$L_{pull} = \frac{1}{\sum_{i=1}^{n} K_i} \sum_{i=1}^{n} \sum_{k=1}^{K_i} (h_i - h'_{ik})^2$$
n To "push apart" embeddings across different vertices, the penalty of [20] gave convergence difficulties; the most reliable alternative was a margin-based penalty similar to [9]:
$$L_{push} = \sum_{i=1}^{n-1} \sum_{j=i+1}^{n} \max(0,\, m - \|h_i - h_j\|)$$
n Intuitively, $L_{push}$ is highest when $h_i$ and $h_j$ are close to each other and drops sharply to zero once their distance exceeds the margin $m$; conversely, for an edge connected to vertex $v_i$, $L_{pull}$ grows the further its reference embedding $h'_i$ is from $h_i$. The two penalties are weighted equally, giving the final associative embedding loss $L_{pull} + L_{push}$. The paper uses $m = 8$ and $d = 8$; convergence improves greatly after increasing the tag dimension $d$ above the value of 1 used in [20].
n Once the network is trained with this loss, full construction of the graph is a trivial postprocessing step: the network produces a pool of vertex and edge detections, and for every edge the source and destination embeddings are matched to the closest embedding among the detected vertices. Multiple edges may share the same source and target vertices $v_s$ and $v_t$, and $v_s$ may equal $v_t$.
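Not part of the original slides: a minimal NumPy sketch of the pull/push penalties above, assuming the ground-truth vertex embeddings and the per-vertex edge reference embeddings have already been gathered from the network output. The function name and data layout are hypothetical.

```python
import numpy as np

def associative_embedding_loss(vertex_emb, edge_refs, m=8.0):
    """Sketch of L_pull + L_push as defined above.

    vertex_emb: (n, d) embedding h_i for each ground-truth vertex
    edge_refs:  list of length n; edge_refs[i] is a (K_i, d) array of the
                embeddings h'_ik that edges produce for vertex i
    m:          margin of the push loss (the paper uses m = 8, d = 8)
    """
    n, _ = vertex_emb.shape

    # Pull: every reference embedding should match its vertex embedding.
    total_refs = sum(len(r) for r in edge_refs)
    pull = sum(np.sum((refs - vertex_emb[i]) ** 2)
               for i, refs in enumerate(edge_refs))
    pull /= max(total_refs, 1)

    # Push: embeddings of different vertices should be at least m apart.
    push = 0.0
    for i in range(n - 1):
        for j in range(i + 1, n):
            dist = np.linalg.norm(vertex_emb[i] - vertex_emb[j])
            push += max(0.0, m - dist)

    return pull + push  # the two penalties are weighted equally

# Toy usage: 3 vertices with 8-dim embeddings, 2 edge references each.
rng = np.random.default_rng(0)
h = rng.normal(size=(3, 8))
refs = [h[i] + 0.1 * rng.normal(size=(2, 8)) for i in range(3)]
print(associative_embedding_loss(h, refs))
```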
  • 62. Copyright (C) DeNA Co.,Ltd. All Rights Reserved. Pixels to Graphs by Associative Embedding [Newell & Deng, NIPS 2017]
n Results on Visual Genome (Table 1 of the paper), recall at 50 / 100:

Method          SGGen (no RPN)    SGGen (w/ RPN)    SGCls           PredCls
                R@50    R@100     R@50    R@100     R@50    R@100   R@50    R@100
Lu et al. [18]  –       –         0.3     0.5       11.8    14.1    27.9    35.0
Xu et al. [26]  –       –         3.4     4.2       21.7    24.4    44.8    53.0
Our model       6.7     7.8       9.7     11.3      26.5    30.0    68.0    75.2

n Figure 3 of the paper shows predictions on Visual Genome: in the top row the network must produce all object and relationship detections directly from the image; the second row shows an easier version of the task where object detections are provided. Relationships outlined in green correspond to predictions that correctly matched a ground-truth annotation.
n The paper sets $s_o = 3$ and $s_r = 6$, which is sufficient to completely accommodate the detection annotations for all but a small fraction of cases.
  • 63. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
n Few-shot Learning
• Santoro, Adam, et al. "Meta-learning with memory-augmented neural networks."
• Vinyals, Oriol, et al. "Matching networks for one shot learning." Advances in Neural Information Processing Systems.
• Snell, Jake, Kevin Swersky, and Richard Zemel. "Prototypical networks for few-shot learning." Advances in Neural Information Processing Systems.
• Ravi, Sachin, and Hugo Larochelle. "Optimization as a model for few-shot learning."
• Finn, Chelsea, Pieter Abbeel, and Sergey Levine. "Model-agnostic meta-learning for fast adaptation of deep networks." arXiv preprint arXiv:1703.03400 (2017).
• Duan, Yan, et al. "One-shot imitation learning." Advances in Neural Information Processing Systems.
• Finn, Chelsea, et al. "One-shot visual imitation learning via meta-learning." arXiv preprint arXiv:1709.04905 (2017).
  • 64. Copyright (C) DeNA Co.,Ltd. All Rights Reserved.
n Graph Convolution
• Gilmer, Justin, et al. "Neural message passing for quantum chemistry." arXiv preprint arXiv:1704.01212 (2017).
• Duvenaud, David, et al. "Convolutional networks on graphs for learning molecular fingerprints." Advances in Neural Information Processing Systems.
• Li, Yujia, et al. "Gated graph sequence neural networks." arXiv preprint arXiv:1511.05493 (2015).
• Schütt, Kristof, et al. "Quantum-chemical insights from deep tensor neural networks." Nature Communications 8 (2017): 13890.
• Schütt, Kristof, et al. "SchNet: A continuous-filter convolutional neural network for modeling quantum interactions." Advances in Neural Information Processing Systems.
• Fout, Alex, et al. "Protein Interface Prediction using Graph Convolutional Networks." Advances in Neural Information Processing Systems.
• Newell, Alejandro, and Jia Deng. "Pixels to Graphs by Associative Embedding." Advances in Neural Information Processing Systems.