Commit 49d7ad1e authored by Verena Praher

implement training_step and validation_step for joint training; add a param for weighting the loss function

parent 6b962f89
@@ -9,6 +9,7 @@ import pytorch_lightning as pl
 from datasets.midlevel import df_get_midlevel_set
 from datasets.mtgjamendo import df_get_mtg_set
+dataset_order = ['midlevel', 'mtg']
 
 
 def initialize_weights(module):
     if isinstance(module, nn.Conv2d):
@@ -154,30 +155,31 @@ class ModelMidlevel(BasePtlModel):
     # TODO - deal with the fact that Base Model does not have this parameter
     # maybe not use base model for this model?
 
-    def training_step(self, data_batch, batch_nb, dataset_i):
-        pass
-        # if self.dataset=='midlevel':
-        #     x, _, y = data_batch
-        #     y_hat = self.forward(x)
-        #     y = y.float()
-        #     y_hat = y_hat.float()
-        #     return {'loss':self.my_loss(y_hat, y)}
-        # else:
-        #     return super(ModelMidlevel, self).training_step(data_batch, batch_nb)
+    def training_step(self, data_batch, batch_i, dataset_i=None):
+        x, _, y = data_batch
+        y_hat = self.forward(x, dataset_i)
+        y = y.float()
+        y_hat = y_hat.float()
+        if dataset_order[dataset_i] == 'midlevel':
+            return {'loss': self.midlevel_loss(y_hat, y)}
+        else:
+            return {'loss': self.mtg_loss(y_hat, y)}
 
-    def validation_step(self, data_batch, batch_nb, dataset_i):
-        pass
-        # if self.dataset=='midlevel':
-        #     x, _, y = data_batch
-        #     y_hat = self.forward(x)
-        #     y = y.float()
-        #     y_hat = y_hat.float()
-        #     return {'val_loss': self.my_loss(y_hat, y),
-        #             'y': y.cpu().numpy(),
-        #             'y_hat': y_hat.cpu().numpy(),
-        #             }
-        # else:
-        #     return super(ModelMidlevel, self).validation_step(data_batch, batch_nb)
+    def validation_step(self, data_batch, batch_i, dataset_i=None):
+        # print("running midlevel_mtg_vgg::validation_step")
+        x, _, y = data_batch
+        y_hat = self.forward(x, dataset_i)
+        y = y.float()
+        y_hat = y_hat.float()
+        if dataset_i is not None and dataset_order[dataset_i] == 'midlevel':
+            val_loss = self.midlevel_loss(y_hat, y)
+        else:
+            val_loss = self.mtg_loss(y_hat, y)
+        return {'val_loss': val_loss,
+                'y': y.cpu().numpy(),
+                'y_hat': y_hat.cpu().numpy(),
+                'prog': {'val_loss': val_loss}
+                }
 
     # COMMENT: the following functions can probably be taken from base model
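For context, here is a minimal sketch of what the two per-dataset loss functions used above might look like. The actual midlevel_loss / mtg_loss definitions are not part of this diff, so the regression/multi-label choices and the use of mtg_loss_weight here are assumptions:

import torch.nn.functional as F

def midlevel_loss(y_hat, y):
    # Mid-level perceptual features are continuous targets,
    # so a regression loss such as MSE is a plausible choice.
    return F.mse_loss(y_hat, y)

def mtg_loss(y_hat, y, mtg_loss_weight=1):
    # MTG-Jamendo tagging is multi-label; binary cross-entropy on logits
    # is a plausible choice, scaled by the new --mtg_loss_weight
    # hyperparameter to balance the two tasks during joint training.
    return mtg_loss_weight * F.binary_cross_entropy_with_logits(y_hat, y)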
@@ -231,13 +233,14 @@ class ModelMidlevel(BasePtlModel):
     @pl.data_loader
     def val_dataloader(self):
-        return [DataLoader(dataset=self.midlevel_valset, batch_size=8, shuffle=True),
-                DataLoader(dataset=self.mtg_valset, batch_size=24, shuffle=True)]
+        #return [DataLoader(dataset=self.midlevel_valset, batch_size=8, shuffle=True),
+        #        DataLoader(dataset=self.mtg_valset, batch_size=24, shuffle=True)]
+        return DataLoader(dataset=self.mtg_valset, batch_size=32, shuffle=True)
 
     @pl.data_loader
     def test_dataloader(self):
         # COMMENT: here I guess we are only interested in the mtg data
-        return DataLoader(dataset=self.testset, batch_size=32, shuffle=True)
+        return DataLoader(dataset=self.mtg_testset, batch_size=32, shuffle=True)
 
     @staticmethod
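The extra dataset_i argument implies that the trainer alternates batches from the two training dataloaders and tells the model which dataset each batch came from. As a rough sketch only (this loop is an assumption, not this repo's or pytorch-lightning's actual trainer), the dispatch could look like:

def joint_train_epoch(model, optimizer, midlevel_loader, mtg_loader):
    # Hypothetical driver loop: alternate one batch per dataset and let
    # training_step pick the matching loss via dataset_i, which indexes
    # into dataset_order (['midlevel', 'mtg']).
    for batch_i, (ml_batch, mtg_batch) in enumerate(zip(midlevel_loader, mtg_loader)):
        loss = model.training_step(ml_batch, batch_i, dataset_i=0)['loss']          # 'midlevel'
        loss = loss + model.training_step(mtg_batch, batch_i, dataset_i=1)['loss']  # 'mtg'
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()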
@@ -253,7 +256,7 @@ class ModelMidlevel(BasePtlModel):
         #                 tunable=True)
         parser.opt_list('--slicing_mode', default='slice', options=['full', 'slice'], type=str, tunable=False)
         parser.opt_list('--input_size', default=1024, options=[512, 1024], type=int, tunable=True)
+        parser.opt_list('--mtg_loss_weight', default=1, options=[1, 2, 4, 8], type=int, tunable=True)
 
         # training params (opt)
         #parser.opt_list('--optimizer_name', default='adam', type=str,
...
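The new option is registered through test-tube's HyperOptArgumentParser, so mtg_loss_weight becomes both a CLI flag and a tunable hyperparameter. A small standalone usage sketch (the strategy argument is an assumption; the opt_list call mirrors the repo's existing pattern):

from test_tube import HyperOptArgumentParser

parser = HyperOptArgumentParser(strategy='random_search')
parser.opt_list('--mtg_loss_weight', default=1, options=[1, 2, 4, 8],
                type=int, tunable=True)
hparams = parser.parse_args([])
print(hparams.mtg_loss_weight)  # 1 by default; hyperopt can sample 1, 2, 4, or 8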