Commit 615f22df authored by Verena Praher's avatar Verena Praher

dataset_i is not needed, remove for simplicity

parent fd4a37e1
......@@ -81,14 +81,14 @@ class BasePtlModel(pl.LightningModule):
checkpoint = torch.load(last_ckpt_path, map_location=lambda storage, loc: storage)
self.load_state_dict(checkpoint['state_dict'])
def training_step(self, data_batch, batch_i, dataset_i=None):
def training_step(self, data_batch, batch_i):
x, _, y = data_batch
y_hat = self.forward(x)
y = y.float()
y_hat = y_hat.float()
return {'loss': self.loss(y_hat, y)}
def validation_step(self, data_batch, batch_i, dataset_i=None):
def validation_step(self, data_batch, batch_i):
x, _, y = data_batch
y_hat = self.forward(x)
y = y.float()
......@@ -99,7 +99,7 @@ class BasePtlModel(pl.LightningModule):
'y_hat': y_hat.cpu().numpy()
}
def test_step(self, data_batch, batch_i, dataset_i=None):
def test_step(self, data_batch, batch_i):
x, _, y = data_batch
y_hat = self.forward(x)
y = y.float()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment