Commit 6b962f89 authored by Verena Praher's avatar Verena Praher

add dataset_i=None to allow multiple dataloaders in derived classes

parent d671e94b
......@@ -81,14 +81,15 @@ class BasePtlModel(pl.LightningModule):
checkpoint = torch.load(last_ckpt_path, map_location=lambda storage, loc: storage)
self.load_state_dict(checkpoint['state_dict'])
def training_step(self, data_batch, batch_i):
def training_step(self, data_batch, batch_i, dataset_i=None):
x, _, y = data_batch
y_hat = self.forward(x)
y = y.float()
y_hat = y_hat.float()
return {'loss': self.loss(y_hat, y)}
def validation_step(self, data_batch, batch_i):
def validation_step(self, data_batch, batch_i, dataset_i=None):
print("running BasePtlModel::validation_step")
x, _, y = data_batch
y_hat = self.forward(x)
y = y.float()
......@@ -99,7 +100,7 @@ class BasePtlModel(pl.LightningModule):
'y_hat': y_hat.cpu().numpy()
}
def test_step(self, data_batch, batch_i):
def test_step(self, data_batch, batch_i, dataset_i=None):
x, _, y = data_batch
y_hat = self.forward(x)
y = y.float()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment