Commit 49d7ad1e authored by Verena Praher

implement training_step and validation_step for joint training; add param for weighting loss function

parent 6b962f89
@@ -9,6 +9,7 @@ import pytorch_lightning as pl
 from datasets.midlevel import df_get_midlevel_set
 from datasets.mtgjamendo import df_get_mtg_set
+dataset_order = ['midlevel', 'mtg']
 def initialize_weights(module):
     if isinstance(module, nn.Conv2d):
@@ -154,30 +155,31 @@ class ModelMidlevel(BasePtlModel):
     # TODO - deal with the fact that Base Model does not have this parameter
     # maybe not use base model for this model?
-    def training_step(self, data_batch, batch_nb, dataset_i):
-        pass
-        # if self.dataset=='midlevel':
-        #     x, _, y = data_batch
-        #     y_hat = self.forward(x)
-        #     y = y.float()
-        #     y_hat = y_hat.float()
-        #     return {'loss':self.my_loss(y_hat, y)}
-        # else:
-        #     return super(ModelMidlevel, self).training_step(data_batch, batch_nb)
-    def validation_step(self, data_batch, batch_nb, dataset_i):
-        pass
-        # if self.dataset=='midlevel':
-        #     x, _, y = data_batch
-        #     y_hat = self.forward(x)
-        #     y = y.float()
-        #     y_hat = y_hat.float()
-        #     return {'val_loss': self.my_loss(y_hat, y),
-        #             'y': y.cpu().numpy(),
-        #             'y_hat': y_hat.cpu().numpy(),
-        #             }
-        # else:
-        #     return super(ModelMidlevel, self).validation_step(data_batch, batch_nb)
+    def training_step(self, data_batch, batch_i, dataset_i=None):
+        x, _, y = data_batch
+        y_hat = self.forward(x, dataset_i)
+        y = y.float()
+        y_hat = y_hat.float()
+        if dataset_order[dataset_i] == 'midlevel':
+            return {'loss': self.midlevel_loss(y_hat, y)}
+        else:
+            return {'loss': self.mtg_loss(y_hat, y)}
+    def validation_step(self, data_batch, batch_i, dataset_i=None):
+        # print("running midlevel_mtg_vgg::validation_step")
+        x, _, y = data_batch
+        y_hat = self.forward(x, dataset_i)
+        y = y.float()
+        y_hat = y_hat.float()
+        if dataset_i is not None and dataset_order[dataset_i] == 'midlevel':
+            val_loss = self.midlevel_loss(y_hat, y)
+        else:
+            val_loss = self.mtg_loss(y_hat, y)
+        return {'val_loss': val_loss,
+                'y': y.cpu().numpy(),
+                'y_hat': y_hat.cpu().numpy(),
+                'prog': {'val_loss': val_loss}
+                }
     # COMMENT: the following functions can probably be taken from base model
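
Note: the hparams hunk below adds a --mtg_loss_weight option, but the training_step above does not apply it yet. A minimal sketch of how the weight could scale the MTG loss during joint training; self.hparams.mtg_loss_weight is an assumption about where the parsed value lives (this diff does not show it):

def training_step(self, data_batch, batch_i, dataset_i=None):
    x, _, y = data_batch
    y_hat = self.forward(x, dataset_i).float()
    y = y.float()
    if dataset_order[dataset_i] == 'midlevel':
        loss = self.midlevel_loss(y_hat, y)
    else:
        # assumed: weight the MTG loss against the midlevel loss
        loss = self.hparams.mtg_loss_weight * self.mtg_loss(y_hat, y)
    return {'loss': loss}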
@@ -231,13 +233,14 @@ class ModelMidlevel(BasePtlModel):
     @pl.data_loader
     def val_dataloader(self):
-        return [DataLoader(dataset=self.midlevel_valset, batch_size=8, shuffle=True),
-                DataLoader(dataset=self.mtg_valset, batch_size=24, shuffle=True)]
+        #return [DataLoader(dataset=self.midlevel_valset, batch_size=8, shuffle=True),
+        #        DataLoader(dataset=self.mtg_valset, batch_size=24, shuffle=True)]
+        return DataLoader(dataset=self.mtg_valset, batch_size=32, shuffle=True)
     @pl.data_loader
     def test_dataloader(self):
         # COMMENT: here I guess we are only interested in the mtg data
-        return DataLoader(dataset=self.testset, batch_size=32, shuffle=True)
+        return DataLoader(dataset=self.mtg_testset, batch_size=32, shuffle=True)
     @staticmethod
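
If the two-loader val_dataloader is restored for joint validation, the pytorch-lightning version this code targets (note the @pl.data_loader decorators) passes the loader index to validation_step as dataset_i and hands validation_end one list of step outputs per loader. A sketch of aggregating a per-dataset loss under those assumptions; validation_end itself is not part of this diff:

import torch

def validation_end(self, outputs):
    # outputs: one list of validation_step dicts per val loader,
    # in the order of dataset_order ('midlevel', then 'mtg')
    logs = {}
    for name, loader_outputs in zip(dataset_order, outputs):
        losses = torch.stack([out['val_loss'] for out in loader_outputs])
        logs[name + '_val_loss'] = losses.mean()
    # monitor the mean of the two per-dataset losses
    avg_val_loss = torch.stack(list(logs.values())).mean()
    return {'val_loss': avg_val_loss, 'prog': logs}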
@@ -253,7 +256,7 @@ class ModelMidlevel(BasePtlModel):
         # tunable=True)
         parser.opt_list('--slicing_mode', default='slice', options=['full', 'slice'], type=str, tunable=False)
         parser.opt_list('--input_size', default=1024, options=[512, 1024], type=int, tunable=True)
+        parser.opt_list('--mtg_loss_weight', default=1, options=[1, 2, 4, 8], type=int, tunable=True)
         # training params (opt)
         #parser.opt_list('--optimizer_name', default='adam', type=str,
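
The opt_list calls come from test-tube's HyperOptArgumentParser. A hedged sketch of how the new flag would be swept; the grid_search strategy and the trials() call are assumptions about how this project runs its hyperparameter search:

from test_tube import HyperOptArgumentParser

parser = HyperOptArgumentParser(strategy='grid_search')
parser.opt_list('--mtg_loss_weight', default=1, options=[1, 2, 4, 8],
                type=int, tunable=True)
hparams = parser.parse_args([])
print(hparams.mtg_loss_weight)  # 1, the default, outside a sweep
for trial in hparams.trials(4):  # enumerate tunable options in a sweep
    print(trial.mtg_loss_weight)  # one of 1, 2, 4, 8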