Paul Primus / dcase2020_task2 / Commits / 5baf14c5

Commit 5baf14c5, authored Jun 03, 2020 by Paul Primus

    Switched to stable MADE implementation

parent 2366952c
Changes: 12 files
dcase2020_task2/experiments/baseline_experiment.py
@@ -18,7 +18,7 @@ class BaselineExperiment(BaseExperiment, pl.LightningModule):

     def __init__(self, configuration_dict, _run):
         super().__init__(configuration_dict)
-        self.network = self.objects['auto_encoder_model']
+        self.network = self.objects['model']
         self.reconstruction = self.objects['reconstruction']
         self.logger_ = Logger(_run, self, self.configuration_dict, self.objects)
@@ -143,25 +143,24 @@ def configuration():
     }

     reconstruction = {
-        'class': 'dcase2020_task2.losses.MSEReconstruction',
+        'class': 'dcase2020_task2.losses.NLLReconstruction',
         'args': [
             '@data_set.observation_shape',
         ],
         'kwargs': {
-            'weight': 1.0,
-            'input_shape': '@data_set.observation_shape'
+            'weight': 1.0
         }
     }

-    auto_encoder_model = {
+    model = {
         'class': 'dcase2020_task2.models.MADE',
         'args': [
             '@data_set.observation_shape',
-            '@reconstruction',
-            {}
+            '@reconstruction'
         ],
         'kwargs': {
-            'hidden_sizes': [4096, 4096, 4096, 4096],
-            'natural_ordering': True
+            'hidden_size': 4096,
+            'num_hidden': 4
         }
     }
@@ -179,7 +178,7 @@ def configuration():
     optimizer = {
         'class': 'torch.optim.Adam',
         'args': [
-            '@auto_encoder_model.parameters()'
+            '@model.parameters()'
         ],
         'kwargs': {
             'lr': learning_rate,
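Resolved by hand, the new configuration wires up roughly as below. This is a sketch: `observation_shape` stands in for `@data_set.observation_shape`, the learning rate value is illustrative, and the `MADE` signature is inferred from the config's args/kwargs (models/made.py is not part of this diff).

import torch
from dcase2020_task2.losses import NLLReconstruction
from dcase2020_task2.models import MADE

observation_shape = (1, 128, 5)  # placeholder for '@data_set.observation_shape'

# '@' references resolve to previously constructed config objects
reconstruction = NLLReconstruction(observation_shape, weight=1.0)
model = MADE(observation_shape, reconstruction, hidden_size=4096, num_hidden=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # learning_rate placeholder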
dcase2020_task2/losses/__init__.py
 from dcase2020_task2.losses.base_loss import BaseReconstruction, BaseLoss
 from dcase2020_task2.losses.mse_loss import MSEReconstruction
+from dcase2020_task2.losses.nll_loss import NLLReconstruction
 from dcase2020_task2.losses.np_loss import NP
 from dcase2020_task2.losses.auc_loss import AUC
dcase2020_task2/losses/nll_loss.py (new file, 100644)
from dcase2020_task2.losses import BaseReconstruction
import torch
import torch.nn.functional as F
from torch import nn as nn, distributions as D
import numpy as np


class NLLReconstruction(BaseReconstruction):

    def __init__(self, input_shape, weight=1.0, **kwargs):
        super().__init__()
        self.weight = weight
        self.input_shape = input_shape

        # base distribution for calculation of log prob under the model
        self.register_buffer('base_dist_mean', torch.zeros(np.prod(input_shape)))
        self.register_buffer('base_dist_var', torch.ones(np.prod(input_shape)))

    @property
    def base_dist(self):
        return D.Normal(self.base_dist_mean, self.base_dist_var)

    def forward(self, batch):
        # prepare observations and prediction based on loss type:
        # use linear outputs & normalized observations as is

        # MAF eq 4 -- return mean and log std
        batch['m'], batch['loga'] = batch['pre_reconstructions'].chunk(chunks=2, dim=1)
        # these values should be normally distributed...
        batch['u'] = (batch['observations'].view(len(batch['observations']), -1) - batch['m']) * torch.exp(-batch['loga'])
        # MAF eq 5
        batch['log_abs_det_jacobian'] = -batch['loga']
        # log probability
        batch['log_proba'] = torch.sum(self.base_dist.log_prob(batch['u']) + batch['log_abs_det_jacobian'], dim=1)
        # scores
        batch['scores'] = -batch['log_proba']
        batch['visualizations'] = batch['u'].view(-1, *self.input_shape)
        # loss
        batch['reconstruction_loss_raw'] = -batch['log_proba'].mean()
        batch['reconstruction_loss'] = self.weight * batch['reconstruction_loss_raw']
        return batch
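Under the change of variables above, log p(x) = sum_i [log N(u_i; 0, 1) - loga_i]. A minimal smoke test (a sketch: shapes are illustrative; the use of register_buffer implies BaseReconstruction is an nn.Module, so the instance is callable):

import numpy as np
import torch

input_shape = (1, 128, 5)
n = int(np.prod(input_shape))  # 640 flattened features
loss = NLLReconstruction(input_shape, weight=1.0)

batch = {
    'observations': torch.randn(8, *input_shape),
    # the network must emit mean and log std concatenated along dim 1
    'pre_reconstructions': torch.randn(8, 2 * n),
}
batch = loss(batch)
print(batch['scores'].shape)         # torch.Size([8]) -- per-example anomaly scores
print(batch['reconstruction_loss'])  # scalar weighted negative log-likelihood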
dcase2020_task2/models/__init__.py
from dcase2020_task2.models.base_model import ClassifierBase, VAEBase
from dcase2020_task2.models.fc_baseline import BaselineFCAE, BaselineFCNN
from dcase2020_task2.models.fc_sampling import SamplingFCAE, SamplingFCGenerator, SamplingCRNNAE
from dcase2020_task2.models.cnn_baseline import BaselineCNN
from dcase2020_task2.models.fc_reduced import ReducedFCAE
from dcase2020_task2.models.made import MADE
\ No newline at end of file
dcase2020_task2/models/fc_reduced.py → dcase2020_task2/models/ae.py
@@ -3,41 +3,51 @@ from dcase2020_task2.models import VAEBase
 import numpy as np
 import torch

-def init_weights(m):
-    if type(m) == torch.nn.Linear:
-        torch.nn.init.xavier_uniform_(m.weight, gain=torch.nn.init.calculate_gain('relu'))
-        m.bias.data.fill_(0.01)
+from dcase2020_task2.models.custom import activation_dict, init_weights

-class ReducedFCAE(torch.nn.Module, VAEBase):
+class BaselineFCAE(torch.nn.Module, VAEBase):

-    def __init__(self, input_shape, prior, reconstruction_loss):
+    def __init__(self, input_shape, reconstruction_loss, prior,
+                 hidden_size=128, num_hidden=3, activation='relu', batch_norm=False):
         super().__init__()
+        activation_fn = activation_dict[activation]
         self.input_shape = input_shape
         self.prior = prior
         self.reconstruction = reconstruction_loss

-        self.encoder = torch.nn.Sequential(
-            torch.nn.Linear(np.prod(input_shape), 256),
-            torch.nn.ReLU(True),
-            torch.nn.Linear(256, 128),
-            torch.nn.ReLU(True),
-            torch.nn.Linear(128, prior.input_size),
-        )
-
-        self.decoder = torch.nn.Sequential(
-            torch.nn.Linear(prior.latent_size, 128),
-            torch.nn.ReLU(True),
-            torch.nn.Linear(128, 256),
-            torch.nn.ReLU(True),
-            torch.nn.Linear(256, np.prod(input_shape)),
-        )
+        # encoder sizes / layers
+        sizes = [np.prod(input_shape)] + [hidden_size] * num_hidden + [prior.input_size]
+        encoder_layers = []
+        for i, o in zip(sizes[:-1], sizes[1:]):
+            encoder_layers.append(torch.nn.Linear(i, o))
+            if batch_norm:
+                encoder_layers.append(torch.nn.BatchNorm1d(o))
+            encoder_layers.append(activation_fn())
+
+        # symmetric decoder sizes / layers
+        sizes = sizes[::-1]
+        decoder_layers = []
+        for i, o in zip(sizes[:-1], sizes[1:]):
+            decoder_layers.append(torch.nn.Linear(i, o))
+            if batch_norm:
+                decoder_layers.append(torch.nn.BatchNorm1d(o))
+            decoder_layers.append(activation_fn())
+        # remove last relu
+        _ = decoder_layers.pop()
+
+        self.encoder = torch.nn.Sequential(*encoder_layers)
+        self.decoder = torch.nn.Sequential(*decoder_layers)
+
+        self.apply(init_weights)

     def forward(self, batch):
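A hypothetical usage sketch of the refactored class. `DummyPrior` and `identity_reconstruction` are stand-ins for the project's prior/loss interfaces (the code above only requires `input_size`, `latent_size`, and a call mapping `pre_codes` to `codes`), and it assumes the unchanged forward/encode/decode below the fold mirror the deleted fc_baseline implementation:

import torch


class DummyPrior:
    # stand-in prior: identity mapping from pre_codes to codes
    input_size = 8
    latent_size = 8

    def __call__(self, batch):
        batch['codes'] = batch['pre_codes']
        return batch


def identity_reconstruction(batch):
    # stand-in for a reconstruction loss module
    return batch


ae = BaselineFCAE((1, 128, 5), identity_reconstruction, DummyPrior(),
                  hidden_size=128, num_hidden=3, batch_norm=True)
batch = ae({'observations': torch.randn(16, 1, 128, 5)})
print(batch['pre_reconstructions'].shape)  # torch.Size([16, 1, 128, 5])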
dcase2020_task2/models/cnn_baseline.py (deleted, 100644 → 0)
import torch.nn
from dcase2020_task2.models import VAEBase
import numpy as np
import torch


class BaselineCNN(torch.nn.Module, VAEBase):

    def __init__(self, input_shape, reconstruction_loss, prior):
        super().__init__()
        self.input_shape = input_shape
        self.prior = prior
        self.reconstruction = reconstruction_loss

        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(input_shape[0], 32, (3, 3), padding=1),
            torch.nn.ReLU(True),
            torch.nn.MaxPool2d((2, 1)),
            torch.nn.Conv2d(32, 64, (3, 3), padding=1),
            torch.nn.ReLU(True),
            torch.nn.MaxPool2d((2, 1)),
            torch.nn.Conv2d(64, 128, (3, 3), padding=1),
            torch.nn.ReLU(True),
            torch.nn.MaxPool2d((2, 1)),
            torch.nn.Conv2d(128, prior.latent_size, (3, 3), padding=1),
            torch.nn.ReLU(True),
            torch.nn.AdaptiveAvgPool2d((1, 1))
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(prior.latent_size, prior.latent_size, (input_shape[-2] // 8, input_shape[-1])),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(prior.latent_size, 128, (3, 3), stride=(2, 1), padding=(1, 1), output_padding=(1, 0)),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(128, 64, (3, 3), stride=(2, 1), padding=(1, 1), output_padding=(1, 0)),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(64, 32, (3, 3), stride=(2, 1), padding=(1, 1), output_padding=(1, 0)),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(32, input_shape[0], (3, 3), padding=1)
        )

    def forward(self, batch):
        batch = self.encode(batch)
        batch = self.prior(batch)
        batch = self.decode(batch)
        return batch

    def encode(self, batch):
        x = batch['observations']
        batch['pre_codes'] = self.encoder(x).view(x.shape[0], -1)
        return batch

    def decode(self, batch):
        batch['pre_reconstructions'] = self.decoder(batch['codes'].view(-1, self.prior.latent_size, 1, 1))
        batch = self.reconstruction(batch)
        return batch


'''
from priors.no_prior import NoPrior
from losses.mse_loss import MSE
import torch

input_shape = (1, 128, 5)
prior = NoPrior(latent_size=8)
mse = MSE()
cnn = BaselineCNN(input_shape, mse, prior)
input = {
    'observations': torch.ones(1, *input_shape)
}
cnn(input)
'''
\ No newline at end of file
dcase2020_task2/models/custom.py (new file, 100644)
import copy
import math

import torch
from torch import nn as nn
from torch.nn import functional as F


activation_dict = {
    'relu': torch.nn.ReLU,
    'tanh': torch.nn.Tanh
}


def init_weights(m):
    if type(m) == torch.nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight, gain=torch.nn.init.calculate_gain('relu'))
        m.bias.data.fill_(0.01)


# --------------------
# Model layers and helpers
# --------------------

def create_masks(input_size, hidden_size, n_hidden, input_order='sequential', input_degrees=None):
    # MADE paper sec 4:
    # degrees of connections between layers -- ensure at most in_degree - 1 connections
    degrees = []

    # set input degrees to what is provided in args (the flipped order of the previous layer
    # in a stack of mades); else init input degrees based on strategy in input_order
    # (sequential or random)
    if input_order == 'sequential':
        degrees += [torch.arange(input_size)] if input_degrees is None else [input_degrees]
        for _ in range(n_hidden + 1):
            degrees += [torch.arange(hidden_size) % (input_size - 1)]
        degrees += [torch.arange(input_size) % input_size - 1] if input_degrees is None \
            else [input_degrees % input_size - 1]

    elif input_order == 'random':
        degrees += [torch.randperm(input_size)] if input_degrees is None else [input_degrees]
        for _ in range(n_hidden + 1):
            min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
            degrees += [torch.randint(min_prev_degree, input_size, (hidden_size,))]
        min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
        degrees += [torch.randint(min_prev_degree, input_size, (input_size,)) - 1] if input_degrees is None \
            else [input_degrees - 1]

    # construct masks
    masks = []
    for (d0, d1) in zip(degrees[:-1], degrees[1:]):
        masks += [(d1.unsqueeze(-1) >= d0.unsqueeze(0)).float()]

    return masks, degrees[0]
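A quick sanity check of the mask construction (a sketch): with 4 inputs, 6 hidden units and n_hidden=1, the function builds degree vectors for the input, two hidden layers and the output, hence three masks.

import torch

masks, input_degrees = create_masks(input_size=4, hidden_size=6, n_hidden=1,
                                    input_order='sequential')
for m in masks:
    print(m.shape)    # torch.Size([6, 4]), torch.Size([6, 6]), torch.Size([4, 6])
print(input_degrees)  # tensor([0, 1, 2, 3])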
class MaskedLinear(nn.Linear):
    """ MADE building block layer """

    def __init__(self, input_size, n_outputs, mask, cond_label_size=None):
        super().__init__(input_size, n_outputs)

        self.register_buffer('mask', mask)

        self.cond_label_size = cond_label_size
        if cond_label_size is not None:
            self.cond_weight = nn.Parameter(torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size))

    def forward(self, x, y=None):
        out = F.linear(x, self.weight * self.mask, self.bias)
        if y is not None:
            out = out + F.linear(y, self.cond_weight)
        return out

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        ) + (self.cond_label_size != None) * ', cond_features={}'.format(self.cond_label_size)
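Combining create_masks with MaskedLinear, the autoregressive property can be checked by gradient inspection (a sketch): an output unit must receive zero gradient from any input whose degree is not strictly below its own.

import torch

input_size, hidden_size = 4, 6
masks, _ = create_masks(input_size, hidden_size, n_hidden=1)

net = torch.nn.Sequential(
    MaskedLinear(input_size, hidden_size, masks[0]), torch.nn.ReLU(),
    MaskedLinear(hidden_size, hidden_size, masks[1]), torch.nn.ReLU(),
    MaskedLinear(hidden_size, input_size, masks[2]),
)

x = torch.randn(1, input_size, requires_grad=True)
net(x)[0, 2].backward()
print(x.grad)  # zero at positions 2 and 3: output 2 (degree 1) only sees inputs 0 and 1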
class LinearMaskedCoupling(nn.Module):
    """ Modified RealNVP Coupling Layers per the MAF paper """

    def __init__(self, input_size, hidden_size, n_hidden, mask, cond_label_size=None):
        super().__init__()

        self.register_buffer('mask', mask)

        # scale function
        s_net = [nn.Linear(input_size + (cond_label_size if cond_label_size is not None else 0), hidden_size)]
        for _ in range(n_hidden):
            s_net += [nn.Tanh(), nn.Linear(hidden_size, hidden_size)]
        s_net += [nn.Tanh(), nn.Linear(hidden_size, input_size)]
        self.s_net = nn.Sequential(*s_net)

        # translation function
        self.t_net = copy.deepcopy(self.s_net)
        # replace Tanh with ReLU's per MAF paper
        for i in range(len(self.t_net)):
            if not isinstance(self.t_net[i], nn.Linear):
                self.t_net[i] = nn.ReLU()

    def forward(self, x, y=None):
        # apply mask
        mx = x * self.mask
        # run through model
        s = self.s_net(mx if y is None else torch.cat([y, mx], dim=1))
        t = self.t_net(mx if y is None else torch.cat([y, mx], dim=1))
        # cf RealNVP eq 8 where u corresponds to x (here we're modeling u)
        u = mx + (1 - self.mask) * (x - t) * torch.exp(-s)
        # log det du/dx; cf RealNVP 8 and 6; note, sum over input_size done at model log_prob
        log_abs_det_jacobian = -(1 - self.mask) * s
        return u, log_abs_det_jacobian

    def inverse(self, u, y=None):
        # apply mask
        mu = u * self.mask
        # run through model
        s = self.s_net(mu if y is None else torch.cat([y, mu], dim=1))
        t = self.t_net(mu if y is None else torch.cat([y, mu], dim=1))
        # cf RealNVP eq 7
        x = mu + (1 - self.mask) * (u * s.exp() + t)
        # log det dx/du
        log_abs_det_jacobian = (1 - self.mask) * s
        return x, log_abs_det_jacobian
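A round-trip check that inverse undoes forward for the coupling layer (a sketch):

import torch

mask = torch.tensor([1., 0., 1., 0.])
layer = LinearMaskedCoupling(input_size=4, hidden_size=16, n_hidden=1, mask=mask)

x = torch.randn(2, 4)
u, log_det = layer(x)
x_rec, _ = layer.inverse(u)
print(torch.allclose(x, x_rec, atol=1e-5))  # True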
class BatchNorm(nn.Module):
    """ RealNVP BatchNorm layer """

    def __init__(self, input_size, momentum=0.9, eps=1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps

        self.log_gamma = nn.Parameter(torch.zeros(input_size))
        self.beta = nn.Parameter(torch.zeros(input_size))

        self.register_buffer('running_mean', torch.zeros(input_size))
        self.register_buffer('running_var', torch.ones(input_size))

    def forward(self, x, cond_y=None):
        if self.training:
            self.batch_mean = x.mean(0)
            # note MAF paper uses biased variance estimate; ie x.var(0, unbiased=False)
            self.batch_var = x.var(0)

            # update running mean
            self.running_mean.mul_(self.momentum).add_(self.batch_mean.data * (1 - self.momentum))
            self.running_var.mul_(self.momentum).add_(self.batch_var.data * (1 - self.momentum))

            mean = self.batch_mean
            var = self.batch_var
        else:
            mean = self.running_mean
            var = self.running_var

        # compute normalized input (cf original batch norm paper algo 1)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        y = self.log_gamma.exp() * x_hat + self.beta

        # compute log_abs_det_jacobian (cf RealNVP paper)
        log_abs_det_jacobian = self.log_gamma - 0.5 * torch.log(var + self.eps)

        return y, log_abs_det_jacobian.expand_as(x)

    def inverse(self, y, cond_y=None):
        if self.training:
            mean = self.batch_mean
            var = self.batch_var
        else:
            mean = self.running_mean
            var = self.running_var

        x_hat = (y - self.beta) * torch.exp(-self.log_gamma)
        x = x_hat * torch.sqrt(var + self.eps) + mean

        log_abs_det_jacobian = 0.5 * torch.log(var + self.eps) - self.log_gamma

        return x, log_abs_det_jacobian.expand_as(x)
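The BatchNorm layer is likewise invertible; in eval mode both directions use the same running statistics, so the round trip is exact (a sketch):

import torch

bn = BatchNorm(input_size=4)
bn.eval()  # use running statistics in both directions

x = torch.randn(8, 4)
y, log_det = bn(x)
x_rec, _ = bn.inverse(y)
print(torch.allclose(x, x_rec, atol=1e-5))  # True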
class FlowSequential(nn.Sequential):
    """ Container for layers of a normalizing flow """

    def forward(self, x, y):
        sum_log_abs_det_jacobians = 0
        for module in self:
            x, log_abs_det_jacobian = module(x, y)
            sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian
        return x, sum_log_abs_det_jacobians

    def inverse(self, u, y):
        sum_log_abs_det_jacobians = 0
        for module in reversed(self):
            u, log_abs_det_jacobian = module.inverse(u, y)
            sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian
        return u, sum_log_abs_det_jacobians
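These building blocks compose into a full flow, e.g. alternating coupling and batch-norm layers (a sketch, not a class defined in this commit; y is the optional conditioning input):

import torch

mask = torch.tensor([1., 0., 1., 0.])
flow = FlowSequential(
    LinearMaskedCoupling(4, 16, 1, mask),
    BatchNorm(4),
    LinearMaskedCoupling(4, 16, 1, 1 - mask),
)
flow.eval()

x = torch.randn(2, 4)
u, sum_log_det = flow(x, None)   # forward accumulates per-layer log-determinants
x_rec, _ = flow.inverse(u, None)
print(torch.allclose(x, x_rec, atol=1e-5))  # True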
dcase2020_task2/models/fc_baseline.py (deleted, 100644 → 0)
import torch.nn
from dcase2020_task2.models import VAEBase
import numpy as np
import torch


def init_weights(m):
    if type(m) == torch.nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight, gain=torch.nn.init.calculate_gain('relu'))
        m.bias.data.fill_(0.01)


class BaselineFCAE(torch.nn.Module, VAEBase):

    def __init__(self, input_shape, reconstruction_loss, prior):
        super().__init__()
        self.input_shape = input_shape
        self.prior = prior
        self.reconstruction = reconstruction_loss

        self.encoder = torch.nn.Sequential(
            # 1
            torch.nn.Linear(np.prod(input_shape), 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(True),
            # 2
            torch.nn.Linear(128, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(True),
            # 3
            torch.nn.Linear(128, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(True),
            # 4
            torch.nn.Linear(128, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(True),
            # bn
            torch.nn.Linear(128, prior.input_size),
            torch.nn.BatchNorm1d(prior.input_size),
            torch.nn.ReLU(True)
        )

        self.decoder = torch.nn.Sequential(
            # 5
            torch.nn.Linear(prior.latent_size, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(True),
            # 6
            torch.nn.Linear(128, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(True),
            # 7
            torch.nn.Linear(128, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(True),
            # 8
            torch.nn.Linear(128, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(True),
            # out
            torch.nn.Linear(128, np.prod(input_shape))
        )

        self.apply(init_weights)

    def forward(self, batch):
        batch = self.encode(batch)
        batch = self.prior(batch)
        batch = self.decode(batch)
        return batch

    def encode(self, batch):
        x = batch['observations']
        x = x.view(x.shape[0], -1)
        batch['pre_codes'] = self.encoder(x)
        return batch

    def decode(self, batch):
        batch['pre_reconstructions'] = self.decoder(batch['codes']).view(-1, *self.input_shape)
        batch = self.reconstruction(batch)
        return batch


class BaselineFCNN(torch.nn.Module):

    def __init__(self, input_shape, reconstruction_loss):
        super().__init__()
        self.input_shape = input_shape
        self.reconstruction = reconstruction_loss