Commit c95586f5 authored by Paul Primus
Browse files

made experiments

parent 3b06d542
......@@ -101,18 +101,17 @@ def configuration():
context = 8
model_class = 'dcase2020_task2.models.ConvAE' # 'dcase2020_task2.models.MADE'
hidden_size = 1024
num_hidden = 4
hidden_size = 256
num_hidden = 1
latent_size = 8 # only used for AEs
debug = False
if debug:
epochs = 1
num_workers = 0
else:
epochs = 100
num_workers = 1
num_workers = 4
epochs = 100
reconstruction_class = 'dcase2020_task2.losses.MSEReconstruction' # 'dcase2020_task2.losses.NLLReconstruction'
batch_size = 256
learning_rate = 1e-3
......
......@@ -5,7 +5,7 @@ import numpy as np
import torch
from dcase2020_task2.models.custom import activation_dict, init_weights
from torchsummary import summary
class AE(torch.nn.Module, VAEBase):
......@@ -16,7 +16,7 @@ class AE(torch.nn.Module, VAEBase):
prior=None,
hidden_size=128,
num_hidden=3,
activation='relu',
activation='elu',
batch_norm=False
):
super().__init__()
......@@ -72,14 +72,15 @@ class AE(torch.nn.Module, VAEBase):
return batch
class ResidualBlock(torch.nn.Module):
class ConvlBlock(torch.nn.Module):
def __init__(
self,
n_units,
n_layers=2,
kernel_size=(3, 3),
activation='relu'
activation='relu',
batch_norm=True
):
super().__init__()
......@@ -93,11 +94,12 @@ class ResidualBlock(torch.nn.Module):
padding=(kernel_size[0]//2, kernel_size[1]//2)
)
)
modules.append(
torch.nn.BatchNorm2d(
n_units
if batch_norm:
modules.append(
torch.nn.BatchNorm2d(
n_units
)
)
)
modules.append(
activation_dict[activation]()
)
......@@ -109,7 +111,7 @@ class ResidualBlock(torch.nn.Module):
)
def forward(self, x):
x = self.block(x) + x
x = self.block(x) # + x
return self.last_activation(x)
......@@ -121,8 +123,9 @@ class ConvAE(torch.nn.Module, VAEBase):
reconstruction_loss,
prior=None,
hidden_size=128,
num_hidden=3,
activation='relu'
num_hidden=1,
activation='relu',
batch_norm=False
):
super().__init__()
......@@ -133,65 +136,81 @@ class ConvAE(torch.nn.Module, VAEBase):
self.prior = prior
self.reconstruction = reconstruction_loss
self.input = torch.nn.Conv2d(
input_shape[0],
hidden_size,
kernel_size=1
)
input = [
torch.nn.Conv2d(
input_shape[0],
hidden_size,
kernel_size=1
)
]
if batch_norm:
input.append(
torch.nn.BatchNorm2d(hidden_size)
)
input.append(activation_fn())
self.input = torch.nn.Sequential(*input)
self.block1 = ResidualBlock(
self.block1 = ConvlBlock(
hidden_size,
n_layers=1,
n_layers=num_hidden,
kernel_size=(3, 3),
activation='relu'
activation=activation,
batch_norm=batch_norm
)
self.pool1 = torch.nn.MaxPool2d(2, return_indices=True)
self.pool1 = torch.nn.AvgPool2d(2)
self.block2 = ResidualBlock(
self.block2 = ConvlBlock(
hidden_size,
n_layers=1,
n_layers=num_hidden,
kernel_size=(3, 3),
activation='relu'
activation=activation,
batch_norm=batch_norm
)
self.pool2 = torch.nn.MaxPool2d(2, return_indices=True)
self.pool2 = torch.nn.AvgPool2d(2)
self.block3 = ResidualBlock(
self.block3 = ConvlBlock(
hidden_size,
n_layers=1,
n_layers=num_hidden,
kernel_size=(3, 3),
activation='relu'
activation=activation,
batch_norm=batch_norm
)
self.pool3 = torch.nn.MaxPool2d(2, return_indices=True)
self.pool3 = torch.nn.AvgPool2d(2)
self.pre_pool_size = hidden_size * input_shape[1] // 8 * input_shape[2] // 8
pre_hidden_size = hidden_size * input_shape[1]//8 * input_shape[2]//8
self.pre_prior = torch.nn.Sequential(
torch.nn.Linear(pre_hidden_size, self.prior.input_size)
torch.nn.Linear(self.pre_pool_size, self.prior.input_size)
)
self.post_prior = torch.nn.Sequential(
torch.nn.Linear(self.prior.latent_size, pre_hidden_size),
torch.nn.Linear(self.prior.latent_size, self.pre_pool_size),
activation_fn()
)
self.block4 = ResidualBlock(
self.block4 = ConvlBlock(
hidden_size,
n_layers=1,
n_layers=num_hidden,
kernel_size=(3, 3),
activation='relu'
activation=activation,
batch_norm=batch_norm
)
self.block5 = ResidualBlock(
self.block5 = ConvlBlock(
hidden_size,
n_layers=1,
n_layers=num_hidden,
kernel_size=(3, 3),
activation='relu'
activation=activation,
batch_norm=batch_norm
)
self.block6 = ResidualBlock(
self.block6 = ConvlBlock(
hidden_size,
n_layers=1,
n_layers=num_hidden,
kernel_size=(3, 3),
activation='relu'
activation=activation,
batch_norm=batch_norm
)
self.output = torch.nn.Conv2d(
......@@ -214,17 +233,21 @@ class ConvAE(torch.nn.Module, VAEBase):
def encode(self, batch):
x = batch['observations']
x = self.input(x)
x, idx1 = self.pool1(self.block1(x))
x, idx2 = self.pool2(self.block2(x))
x, idx3 = self.pool3(self.block3(x))
x = self.pool1(self.block1(x))
x = self.pool2(self.block2(x))
x = self.pool3(self.block3(x))
batch['pre_codes'] = x
batch['pool_indices'] = [idx1, idx2, idx3]
# batch['pool_indices'] = [idx1, idx2, idx3]
return batch
def decode(self, batch):
x = torch.nn.functional.max_unpool2d(batch['post_codes'], batch['pool_indices'][2], 2)
x = torch.nn.functional.max_unpool2d(self.block4(x), batch['pool_indices'][1], 2)
x = torch.nn.functional.max_unpool2d(self.block5(x), batch['pool_indices'][0], 2)
# x = torch.nn.functional.max_unpool2d(batch['post_codes'], batch['pool_indices'][2], self.pre_pool_size)
# x = torch.nn.functional.max_unpool2d(self.block4(x), batch['pool_indices'][1], 2)
# x = torch.nn.functional.max_unpool2d(self.block5(x), batch['pool_indices'][0], 2)
x = torch.nn.functional.interpolate(batch['post_codes'], scale_factor=(2, 2))
x = torch.nn.functional.interpolate(self.block4(x), scale_factor=(2, 2))
x = torch.nn.functional.interpolate(self.block5(x), scale_factor=(2, 2))
batch['pre_reconstructions'] = self.output(self.block6(x))
batch = self.reconstruction(batch)
return batch
......
......@@ -7,7 +7,9 @@ from torch.nn import functional as F
activation_dict = {
'relu': torch.nn.ReLU,
'tanh': torch.nn.Tanh
'tanh': torch.nn.Tanh,
'elu': torch.nn.ELU,
'prelu': torch.nn.PReLU
}
def init_weights(m):
......
# Launch a small hyperparameter grid for the ConvAE baseline with MSE reconstruction.
#
# Eight runs are started, one per combination of:
#   hidden_size  in {128, 256}
#   latent_size  in {4, 8}
#   weight_decay in {0, 1e-5}
# debug, num_hidden, model_class and reconstruction_class are the same in every run.
#
# Arguments passed to scripts/run_all.sh:
#   $1  experiment name (always baseline_experiment here)
#   $2  a distinct index 0-7 per run -- presumably a GPU/device or job id;
#       TODO confirm against scripts/run_all.sh
#   $3  one quoted string of key=value overrides followed by
#       "-m host:port:db" -- this looks like a MongoDB observer target
#       (port 27017); TODO confirm how run_all.sh forwards it
#
# Each command ends with '&', so all eight are started as background jobs and
# run concurrently. No 'wait' is visible in this excerpt.

# Activate the project environment before launching any run.
conda activate dcase2020_task2
# hidden_size=128, latent_size=4, weight_decay=0
./scripts/run_all.sh baseline_experiment 0 "debug=False num_hidden=1 hidden_size=128 latent_size=4 weight_decay=0 model_class=dcase2020_task2.models.ConvAE reconstruction_class=dcase2020_task2.losses.MSEReconstruction -m student2.cp.jku.at:27017:dcase2020_task2_ae_baseline_gridsearch" &
# hidden_size=128, latent_size=4, weight_decay=1e-5
./scripts/run_all.sh baseline_experiment 1 "debug=False num_hidden=1 hidden_size=128 latent_size=4 weight_decay=1e-5 model_class=dcase2020_task2.models.ConvAE reconstruction_class=dcase2020_task2.losses.MSEReconstruction -m student2.cp.jku.at:27017:dcase2020_task2_ae_baseline_gridsearch" &
# hidden_size=128, latent_size=8, weight_decay=0
./scripts/run_all.sh baseline_experiment 2 "debug=False num_hidden=1 hidden_size=128 latent_size=8 weight_decay=0 model_class=dcase2020_task2.models.ConvAE reconstruction_class=dcase2020_task2.losses.MSEReconstruction -m student2.cp.jku.at:27017:dcase2020_task2_ae_baseline_gridsearch" &
# hidden_size=128, latent_size=8, weight_decay=1e-5
./scripts/run_all.sh baseline_experiment 3 "debug=False num_hidden=1 hidden_size=128 latent_size=8 weight_decay=1e-5 model_class=dcase2020_task2.models.ConvAE reconstruction_class=dcase2020_task2.losses.MSEReconstruction -m student2.cp.jku.at:27017:dcase2020_task2_ae_baseline_gridsearch" &
# hidden_size=256, latent_size=4, weight_decay=0
./scripts/run_all.sh baseline_experiment 4 "debug=False num_hidden=1 hidden_size=256 latent_size=4 weight_decay=0 model_class=dcase2020_task2.models.ConvAE reconstruction_class=dcase2020_task2.losses.MSEReconstruction -m student2.cp.jku.at:27017:dcase2020_task2_ae_baseline_gridsearch" &
# hidden_size=256, latent_size=4, weight_decay=1e-5
./scripts/run_all.sh baseline_experiment 5 "debug=False num_hidden=1 hidden_size=256 latent_size=4 weight_decay=1e-5 model_class=dcase2020_task2.models.ConvAE reconstruction_class=dcase2020_task2.losses.MSEReconstruction -m student2.cp.jku.at:27017:dcase2020_task2_ae_baseline_gridsearch" &
# hidden_size=256, latent_size=8, weight_decay=0
./scripts/run_all.sh baseline_experiment 6 "debug=False num_hidden=1 hidden_size=256 latent_size=8 weight_decay=0 model_class=dcase2020_task2.models.ConvAE reconstruction_class=dcase2020_task2.losses.MSEReconstruction -m student2.cp.jku.at:27017:dcase2020_task2_ae_baseline_gridsearch" &
# hidden_size=256, latent_size=8, weight_decay=1e-5
./scripts/run_all.sh baseline_experiment 7 "debug=False num_hidden=1 hidden_size=256 latent_size=8 weight_decay=1e-5 model_class=dcase2020_task2.models.ConvAE reconstruction_class=dcase2020_task2.losses.MSEReconstruction -m student2.cp.jku.at:27017:dcase2020_task2_ae_baseline_gridsearch" &
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment