Commit 0fa40e24 authored by Verena Praher's avatar Verena Praher
Browse files

cleanup baseline model code

parent d430f4dd
......@@ -106,26 +106,10 @@ class CNN(pl.LightningModule):
return {'loss':self.my_loss(y_hat, y)}
def validation_step(self, data_batch, batch_nb):
    """Run one validation step and compute per-batch metrics.

    Args:
        data_batch: a (x, _, y) triple; the middle element is unused here.
        batch_nb: batch index (unused, required by the Lightning interface).

    Returns:
        dict with 'val_loss', 'val_rocauc', 'val_prauc', 'val_fscore'.
    """
    # Fixes vs. original: removed commented-out debug prints, removed the
    # unreachable duplicate `return validation_step(...)` after the first
    # return, and hoisted the repeated `.t().cpu()` conversions.
    x, _, y = data_batch
    # NOTE(review): forward_full_song presumably aggregates frame-level
    # predictions over a whole song — confirm against the model definition.
    y_hat = self.forward_full_song(x, y)
    y = y.float()
    y_hat = y_hat.float()

    # Move to CPU once; sklearn metrics require host tensors/arrays.
    y_true = y.t().cpu()
    y_pred = y_hat.t().cpu()
    # Binarize predictions at 0.5 only for the F-score; ROC-AUC and PR-AUC
    # use the raw scores.
    decisions = (y_pred > 0.5).type(torch.float)

    rocauc = metrics.roc_auc_score(y_true, y_pred)
    prauc = metrics.average_precision_score(y_true, y_pred)
    _, _, fscore, _ = metrics.precision_recall_fscore_support(
        y_true, decisions, average='micro')
    return {'val_loss': self.my_loss(y_hat, y),
            'val_rocauc': rocauc,
            'val_prauc': prauc,
            'val_fscore': fscore}
def test_step(self, data_batch, batch_nb):
    """Run one test step by delegating to the shared module-level helper.

    Fixes the earlier version, which mistakenly forwarded to ``test_end``
    with mismatched arguments (``test_end`` aggregates outputs, it does not
    process a batch).
    """
    # The bare name `test_step` here resolves to the module-level helper,
    # not this method: class attributes are not in scope inside method bodies.
    return test_step(self, data_batch, batch_nb)
def test_end(self, outputs):
# Aggregate the per-batch test outputs via the shared module-level
# `test_end` helper (the bare name resolves to the module global, not
# this method).
# NOTE(review): this definition is truncated by the diff hunk boundary
# below — the lines that use or return `test_metrics` are not visible
# here, so the full contract cannot be stated.
test_metrics = test_end(outputs)
......@@ -135,12 +119,6 @@ class CNN(pl.LightningModule):
def validation_end(self, outputs):
    """Aggregate per-batch validation outputs via the shared helper.

    The bare name ``validation_end`` resolves to the module-level function,
    not this method, since class attributes are not in method-body scope.
    """
    aggregated = validation_end(outputs)
    return aggregated
def test_step(self, data_batch, batch_nb):
    """Process a single test batch through the shared module-level helper.

    Inside the method body the bare name ``test_step`` resolves to the
    module-level function, not to this method.
    """
    result = test_step(self, data_batch, batch_nb)
    return result
def test_end(self, outputs):
    """Aggregate the collected test outputs via the shared helper.

    ``test_end`` in the body refers to the module-level function — class
    attributes are not visible from inside method bodies.
    """
    aggregated = test_end(outputs)
    return aggregated
def configure_optimizers(self):
    """Build the optimizer list for Lightning.

    Returns:
        A single-element list holding an Adam optimizer over all model
        parameters with lr=1e-4 (the value used by the reference code).
    """
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
    return [optimizer]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment