Commit fd4a37e1 authored by Verena Praher

use self-implemented ConcatDataset; take care of labels in loss function

parent d442eabe
datasets/concat.py

import bisect

import torch
from torch.utils.data import Dataset


class ConcatDataset(Dataset):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        # for d in self.datasets:
        #     assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        x, _, y = self.datasets[dataset_idx][sample_idx]
        if dataset_idx == 0:
            # midlevel labels: zero-pad to the 56 mtg targets and divide by 10
            # to bring the ratings into the same [0, 1] range as the mtg tags
            z = torch.zeros(56)
            z[:len(y)] = y
            y = z / 10
        else:
            y = y.float()
        return x, dataset_idx, y


if __name__ == '__main__':
    from models.midlevel_mtg_vgg import get_mtg_sets, get_midlevel_sets
    from torch.utils.data import DataLoader

    midlevel_trainset, _, _ = get_midlevel_sets()
    mtg_trainset, _, _ = get_mtg_sets()
    trainset = ConcatDataset([midlevel_trainset, mtg_trainset])
    dl = DataLoader(dataset=trainset, batch_size=32, shuffle=True)
    batch = next(iter(dl))
    x, di, y = batch
    print("x", x)
    print("di", di)
    print("y", y)
models/midlevel_mtg_vgg.py

@@ -5,6 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets.concat import ConcatDataset
import pytorch_lightning as pl
from datasets.midlevel import df_get_midlevel_set
from datasets.mtgjamendo import df_get_mtg_set
@@ -47,11 +48,12 @@ def get_midlevel_sets():
class ModelMidlevel(BasePtlModel):
    def __init__(self, config, hparams, num_targets, initialize=True, dataset='midlevel'):
    def __init__(self, config, hparams, num_targets, initialize=True):
        super(ModelMidlevel, self).__init__(config, hparams)
        self.midlevel_trainset, self.midlevel_valset, self.midlevel_testset = get_midlevel_sets()
        self.mtg_trainset, self.mtg_valset, self.mtg_testset = get_mtg_sets()
        self.trainset = ConcatDataset([self.midlevel_trainset, self.mtg_trainset])

        self.num_targets = num_targets
        self.conv1 = nn.Sequential(
@@ -124,9 +126,10 @@ class ModelMidlevel(BasePtlModel):
        if initialize:
            self.apply(initialize_weights)
        self.fc_mtg1 = nn.Linear(256, 56)
        self.fc_mtg1 = nn.Sequential(nn.Linear(256, 56), nn.Sigmoid())

    def forward(self, x, dataset_i=None):
    def forward(self, x, train_mode=False):
        # print("x", x.shape)
        # 313 * 149 * 1
        x = self.conv1(x)  # 157 * 75 * 64
        x = self.conv2(x)  # 157 * 75 * 64
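Moving the Sigmoid into fc_mtg1 matters because the new mtg_loss below calls F.binary_cross_entropy, which expects probabilities rather than raw logits. For reference (not part of the commit), the same quantity can be computed more stably from logits:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 56)
targets = torch.randint(0, 2, (8, 56)).float()
# equivalent up to floating-point error; the with_logits variant fuses the
# sigmoid into the loss and avoids log(0) when the sigmoid saturates
a = F.binary_cross_entropy(torch.sigmoid(logits), targets)
b = F.binary_cross_entropy_with_logits(logits, targets)
assert torch.allclose(a, b, atol=1e-6)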
@@ -141,40 +144,62 @@ class ModelMidlevel(BasePtlModel):
        x = self.conv11(x)  # 2 * 2 * 256
        x = x.view(x.size(0), -1)
        if dataset_i == 0:  # 0 - midlevel
            x = self.fc_ml(x)
        else:  # 1 or None
            x = nn.Sigmoid()(self.fc_mtg1(x))
        return x
        # if dataset_i is None:
        #     x = self.fc_mtg1(x)
        # else:
        #     # 1 - dataset_i should cancel mtg samples
        #     ml = self.fc_ml(x)
        #     print("ml", ml.shape)
        #     mtg = self.fc_mtg1(x)
        #     print("mtg", mtg.shape)
        #     ml = (1 - dataset_i).unsqueeze(1).float() * ml
        #     print("ml", ml.shape)
        #     mtg = dataset_i.unsqueeze(1).float() * mtg
        #     print("mtg", mtg.shape)
        #     x = mtg
        #     x[dataset_i]
        #     x = ml + mtg
        if not train_mode:  # for testing and validation: mtg head only
            return self.fc_mtg1(x)
        return self.fc_ml(x), self.fc_mtg1(x)

    def midlevel_loss(self, y_hat, y):
        return F.mse_loss(y_hat, y)
        # print("midlevel_loss y_hat;", y_hat.shape, "y", y.shape)
        # print("Computing MSE loss between ", y_hat, y[:, :y_hat.shape[1]])
        # labels were zero-padded to 56 entries in ConcatDataset, so cut them
        # back down to the midlevel head's output width before the MSE
        return F.mse_loss(y_hat, y[:, :y_hat.shape[1]])

    def mtg_loss(self, y_hat, y):
        return self.hparams.mtg_loss_weight * super(ModelMidlevel, self).loss(y_hat, y)
        # print("mtg_loss y_hat;", y_hat.shape, "y", y.shape)
        return F.binary_cross_entropy(y_hat, y)

    # TODO - deal with the fact that Base Model does not have this parameter
    # maybe not use base model for this model?
    def training_step(self, data_batch, batch_i, dataset_i=None):
        x, _, y = data_batch
        y_hat = self.forward(x, dataset_i)
    def training_step(self, data_batch, batch_i):
        x, dataset_i, y = data_batch
        y_ml, y_mtg = self.forward(x, train_mode=True)
        y = y.float()
        y_hat = y_hat.float()
        if dataset_order[dataset_i] == 'midlevel':
            return {'loss': self.midlevel_loss(y_hat, y)}
        else:
            return {'loss': self.mtg_loss(y_hat, y)}
        y_ml, y_mtg = y_ml.float(), y_mtg.float()
        # split the mixed batch by origin (0 = midlevel, 1 = mtg) so that
        # each head is only penalized on its own samples
        ml_idx = torch.where(dataset_i == 0)
        mtg_idx = torch.where(dataset_i == 1)
        midlevel_loss = self.midlevel_loss(y_ml[ml_idx], y[ml_idx])
        mtg_loss = self.mtg_loss(y_mtg[mtg_idx], y[mtg_idx])
        # print("midlevel_loss", midlevel_loss)
        # print("mtg_loss", mtg_loss)
        return {'loss': midlevel_loss + self.hparams.mtg_loss_weight * mtg_loss}

    def validation_step(self, data_batch, batch_i, dataset_i=None):
    def validation_step(self, data_batch, batch_i):
        # print("running midlevel_mtg_vgg::validation_step")
        x, _, y = data_batch
        y_hat = self.forward(x, dataset_i)
        y_hat = self.forward(x)
        y = y.float()
        y_hat = y_hat.float()
        if dataset_i is not None and dataset_order[dataset_i] == 'midlevel':
            val_loss = self.midlevel_loss(y_hat, y)
        else:
            val_loss = self.mtg_loss(y_hat, y)
        # if dataset_i is not None and dataset_order[dataset_i] == 'midlevel':
        #     val_loss = self.midlevel_loss(y_hat, y)
        # else:
        val_loss = self.mtg_loss(y_hat, y)
        return {'val_loss': val_loss,
                'y': y.cpu().numpy(),
                'y_hat': y_hat.cpu().numpy(),
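To see the new training_step's masking in isolation, here is a minimal sketch with dummy tensors (the batch size and the 7-target midlevel head are invented for illustration). torch.where(condition) returns index tensors, so each head only contributes a loss on rows from its own dataset:

import torch
import torch.nn.functional as F

dataset_i = torch.tensor([0, 1, 1, 0])              # mixed batch from ConcatDataset
y = torch.rand(4, 56)                               # labels, zero-padded to 56
y_ml, y_mtg = torch.rand(4, 7), torch.rand(4, 56)   # dummy outputs of the two heads

ml_idx = torch.where(dataset_i == 0)                # (tensor([0, 3]),)
mtg_idx = torch.where(dataset_i == 1)               # (tensor([1, 2]),)
# midlevel_loss truncates the padded labels to the head's width internally
ml_loss = F.mse_loss(y_ml[ml_idx], y[ml_idx][:, :7])
mtg_loss = F.binary_cross_entropy(y_mtg[mtg_idx], y[mtg_idx])

One caveat: if a shuffled batch happens to contain samples from only one dataset, the mean reduction over the empty selection yields nan and the summed loss becomes nan as well; guarding on the size of each selection would avoid this.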
@@ -227,9 +252,9 @@ class ModelMidlevel(BasePtlModel):
    @pl.data_loader
    def tng_dataloader(self):
        # TODO: deal with different dataset sizes - use more from mtg? ...
        return [DataLoader(dataset=self.midlevel_trainset, batch_size=8, shuffle=True),
                DataLoader(dataset=self.mtg_trainset, batch_size=24, shuffle=True)]
        # return [DataLoader(dataset=self.midlevel_trainset, batch_size=32, shuffle=True),
        #         DataLoader(dataset=self.mtg_trainset, batch_size=32, shuffle=True)]
        return DataLoader(dataset=self.trainset, batch_size=32, shuffle=True)

    @pl.data_loader
    def val_dataloader(self):
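On the size-imbalance TODO above: one option (a sketch, not part of the commit; the helper name is made up) is to oversample the smaller dataset with a WeightedRandomSampler, so both corpora are drawn from with roughly equal probability per epoch:

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def make_balanced_loader(concat_dataset, batch_size=32):
    # weight each sample inversely to the size of its source dataset;
    # the concatenation order matches ConcatDataset's global indexing
    sizes = [len(d) for d in concat_dataset.datasets]
    weights = torch.cat([torch.full((n,), 1.0 / n) for n in sizes])
    sampler = WeightedRandomSampler(weights, num_samples=len(concat_dataset),
                                    replacement=True)
    # shuffle must stay off when an explicit sampler is given
    return DataLoader(concat_dataset, batch_size=batch_size, sampler=sampler)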
@@ -256,7 +281,7 @@ class ModelMidlevel(BasePtlModel):
        #                 tunable=True)
        parser.opt_list('--slicing_mode', default='slice', options=['full', 'slice'], type=str, tunable=False)
        parser.opt_list('--input_size', default=1024, options=[512, 1024], type=int, tunable=True)
        parser.opt_list('--mtg_loss_weight', default=1, options=[1, 2, 4, 8], type=int, tunable=True)
        parser.opt_list('--mtg_loss_weight', default=1, options=[1, 2, 4, 8, 16, 32], type=int, tunable=True)

        # training params (opt)
        # parser.opt_list('--optimizer_name', default='adam', type=str,