Commit 638b7fda authored by Verena Praher

setup model for joint training (open todos)

parent 16cdda9f
import os  # needed for os.path.join below

from models.shared_stuff import BasePtlModel
from test_tube import HyperOptArgumentParser
from utils import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from datasets.midlevel import df_get_midlevel_set
from datasets.mtgjamendo import df_get_mtg_set


def initialize_weights(module):
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight.data, mode='fan_in', nonlinearity="relu")
        # nn.init.kaiming_normal_(module.weight.data, mode='fan_out')
    elif isinstance(module, nn.BatchNorm2d):
        module.weight.data.fill_(1)
        module.bias.data.zero_()
    elif isinstance(module, nn.Linear):
        module.bias.data.zero_()


def get_midlevel_sets():
    data_root, audio_path, csvs_path = get_paths('midlevel')
    cache_x_name = '_ap_midlevel44k'
    from torch.utils.data import random_split
    dataset, dataset_length = df_get_midlevel_set('midlevel', os.path.join(csvs_path, 'annotations.csv'),
                                                  audio_path, cache_x_name)
    # derive the test split as the remainder so the lengths always sum to
    # dataset_length (int() truncation could otherwise make random_split fail)
    train_length = int(0.7 * dataset_length)
    val_length = int(0.2 * dataset_length)
    trainset, validationset, testset = random_split(
        dataset, [train_length, val_length, dataset_length - train_length - val_length])
    return trainset, validationset, testset
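

# Possible mtg counterpart to get_midlevel_sets(), for the open TODO in
# __init__ below (a sketch only: df_get_mtg_set is assumed to mirror
# df_get_midlevel_set's signature, and the tsv filenames are placeholders):
# def get_mtg_sets():
#     data_root, audio_path, csvs_path = get_paths('mtgjamendo')
#     cache_x_name = '_ap_mtgjamendo44k'
#     trainset = df_get_mtg_set('mtg_train', os.path.join(csvs_path, 'train.tsv'), audio_path, cache_x_name)
#     validationset = df_get_mtg_set('mtg_val', os.path.join(csvs_path, 'validation.tsv'), audio_path, cache_x_name)
#     testset = df_get_mtg_set('mtg_test', os.path.join(csvs_path, 'test.tsv'), audio_path, cache_x_name)
#     return trainset, validationset, testset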


class ModelMidlevel(BasePtlModel):
    def __init__(self, config, hparams, num_targets, initialize=True, dataset='midlevel'):
        super(ModelMidlevel, self).__init__(config, hparams)
        self.midlevel_trainset, self.midlevel_valset, self.midlevel_testset = get_midlevel_sets()
        # TODO same for mtg sets (self.mtg_trainset, self.mtg_valset, self.mtg_testset)
        self.num_targets = num_targets

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, 5, 2, 2),  # (in_channels, out_channels, kernel_size, stride, padding)
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.mp2x2_dropout = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Dropout2d(0.3)
        )
        self.ap2x2_dropout = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Dropout2d(0.3)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.conv6 = nn.Sequential(
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.conv7 = nn.Sequential(
            nn.Conv2d(256, 384, 3, 1, 1),
            nn.BatchNorm2d(384),
            nn.ReLU()
        )
        self.conv7b = nn.Sequential(
            nn.Conv2d(384, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.conv11 = nn.Sequential(
            nn.Conv2d(512, 256, 1, 1, 0),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc_ml = nn.Linear(256, 7)
        # mtg head is created before apply() so initialize_weights covers it too
        self.fc_mtg1 = nn.Linear(256, 56)
        if initialize:
            self.apply(initialize_weights)
        # if load_from:
        #     self._load_model(load_from, map_location, on_gpu)
        # if dataset == 'mtgjamendo':
        # self.fc_mtg2 = nn.Linear(10, 56)
        # for name, param in self.named_parameters():
        #     if 'mtg' in name:
        #         param.requires_grad = True
        #     else:
        #         param.requires_grad = False

    def forward(self, x, dataset_i=None):
        # input: 313 * 149 * 1
        x = self.conv1(x)  # 157 * 75 * 64
        x = self.conv2(x)  # 157 * 75 * 64
        x = self.ap2x2_dropout(x)  # 78 * 37 * 64
        x = self.conv3(x)  # 78 * 37 * 128
        x = self.conv4(x)  # 78 * 37 * 128
        x = self.ap2x2_dropout(x)  # 39 * 18 * 128
        x = self.conv5(x)  # 39 * 18 * 256
        x = self.conv6(x)  # 39 * 18 * 256
        x = self.conv7(x)  # 39 * 18 * 384
        x = self.conv7b(x)  # 39 * 18 * 512
        x = self.conv11(x)  # 1 * 1 * 256 (AdaptiveAvgPool2d collapses the spatial dims)
        x = x.view(x.size(0), -1)
        if dataset_i == 0:  # 0 - midlevel regression head
            x = self.fc_ml(x)
        else:  # 1 or None - mtgjamendo tagging head
            x = torch.sigmoid(self.fc_mtg1(x))
        return x

    def midlevel_loss(self, y_hat, y):
        return F.mse_loss(y_hat, y)

    def mtg_loss(self, y_hat, y):
        return self.hparams.mtg_loss_weight * super(ModelMidlevel, self).loss(y_hat, y)
        # TODO - deal with the fact that Base Model does not have this parameter
        # maybe not use base model for this model?
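
    # One way to settle the TODO above (a sketch, not verified against
    # BasePtlModel): default to a neutral weight when the hparam is absent,
    # so the base model never needs to know about mtg_loss_weight.
    # def mtg_loss(self, y_hat, y):
    #     weight = getattr(self.hparams, 'mtg_loss_weight', 1.0)
    #     return weight * super(ModelMidlevel, self).loss(y_hat, y)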

    def training_step(self, data_batch, batch_nb, dataset_i):
        pass
        # if self.dataset == 'midlevel':
        #     x, _, y = data_batch
        #     y_hat = self.forward(x)
        #     y = y.float()
        #     y_hat = y_hat.float()
        #     return {'loss': self.my_loss(y_hat, y)}
        # else:
        #     return super(ModelMidlevel, self).training_step(data_batch, batch_nb)

    def validation_step(self, data_batch, batch_nb, dataset_i):
        pass
        # if self.dataset == 'midlevel':
        #     x, _, y = data_batch
        #     y_hat = self.forward(x)
        #     y = y.float()
        #     y_hat = y_hat.float()
        #     return {'val_loss': self.my_loss(y_hat, y),
        #             'y': y.cpu().numpy(),
        #             'y_hat': y_hat.cpu().numpy(),
        #             }
        # else:
        #     return super(ModelMidlevel, self).validation_step(data_batch, batch_nb)
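
    # Sketch of a joint training_step for the open todos above (assumption:
    # with two tng dataloaders, Lightning hands each batch over together with
    # its loader index, so dataset_i can pick the head and loss; the batch
    # layout follows the commented-out code):
    # def training_step(self, data_batch, batch_nb, dataset_i):
    #     x, _, y = data_batch
    #     y_hat = self.forward(x, dataset_i)
    #     if dataset_i == 0:  # midlevel
    #         return {'loss': self.midlevel_loss(y_hat.float(), y.float())}
    #     return {'loss': self.mtg_loss(y_hat.float(), y.float())}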

    # COMMENT: the following functions can probably be taken from base model
    # def validation_end(self, outputs):
    #     if self.dataset == 'midlevel':
    #         avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    #         y = []
    #         y_hat = []
    #         for output in outputs:
    #             y.append(output['y'])
    #             y_hat.append(output['y_hat'])
    #
    #         y = np.concatenate(y)
    #         y_hat = np.concatenate(y_hat)
    #
    #         return {'val_loss': avg_loss}
    #     else:
    #         return super(ModelMidlevel, self).validation_end(outputs)

    # def test_step(self, data_batch, batch_nb):
    #     if self.dataset == 'midlevel':
    #         x, _, y = data_batch
    #         y_hat = self.forward(x)
    #         y = y.float()
    #         y_hat = y_hat.float()
    #         return {'test_loss': self.my_loss(y_hat, y),
    #                 'y': y.cpu().numpy(),
    #                 'y_hat': y_hat.cpu().numpy(),
    #                 }
    #     else:
    #         return super(ModelMidlevel, self).test_step(data_batch, batch_nb)

    # def test_end(self, outputs):
    #     if self.dataset == 'midlevel':
    #         avg_test_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
    #         test_metrics = {"test_loss": avg_test_loss}
    #         self.experiment.log(test_metrics)
    #         return test_metrics
    #     else:
    #         return super(ModelMidlevel, self).test_end(outputs)

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=1e-4)]  # from their code

    @pl.data_loader
    def tng_dataloader(self):
        # TODO: deal with different dataset sizes - use more from mtg? (one option sketched below)
        return [DataLoader(dataset=self.midlevel_trainset, batch_size=8, shuffle=True),
                DataLoader(dataset=self.mtg_trainset, batch_size=24, shuffle=True)]
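
    # Sketch for the size-mismatch TODO above (assumption: oversampling the
    # smaller midlevel set is acceptable). A RandomSampler with replacement
    # stretches midlevel to one mtg-sized epoch:
    # midlevel_sampler = torch.utils.data.RandomSampler(
    #     self.midlevel_trainset, replacement=True, num_samples=len(self.mtg_trainset))
    # midlevel_loader = DataLoader(dataset=self.midlevel_trainset, batch_size=8,
    #                              sampler=midlevel_sampler)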

    @pl.data_loader
    def val_dataloader(self):
        return [DataLoader(dataset=self.midlevel_valset, batch_size=8, shuffle=True),
                DataLoader(dataset=self.mtg_valset, batch_size=24, shuffle=True)]

    @pl.data_loader
    def test_dataloader(self):
        # COMMENT: here I guess we are only interested in the mtg data
        # (self.mtg_testset still has to be created in __init__ - see the TODO there)
        return DataLoader(dataset=self.mtg_testset, batch_size=32, shuffle=True)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Parameters defined here will be available to your model through self.hparams
        """
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
        # network params
        # parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=True)
        # parser.opt_list('--learning_rate', default=0.0001, type=float,
        #                 options=[0.00001, 0.0005, 0.001],
        #                 tunable=True)
        parser.opt_list('--slicing_mode', default='slice', options=['full', 'slice'], type=str, tunable=False)
        parser.opt_list('--input_size', default=1024, options=[512, 1024], type=int, tunable=True)
        # training params (opt)
        # parser.opt_list('--optimizer_name', default='adam', type=str,
        #                 options=['adam'], tunable=False)
        # if using 2 nodes with 8 gpus each the batch size here
        # (256) will be 256 / (2*8) = 16 per gpu
        # parser.opt_list('--batch_size', default=32, type=int,
        #                 options=[16, 32], tunable=False,
        #                 help='batch size will be divided over all gpus being used across all nodes')
        return parser
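

# Minimal usage sketch (assumptions: config handling follows the repo's usual
# plumbing and is not shown here; the argument values are illustrative):
# parser = HyperOptArgumentParser(strategy='grid_search')
# parser = ModelMidlevel.add_model_specific_args(parser)
# hparams = parser.parse_args()
# model = ModelMidlevel(config=None, hparams=hparams, num_targets=7)
# trainer = pl.Trainer()
# trainer.fit(model)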