from utils import *
import pytorch_lightning as pl
from models.shared_stuff import *
from sklearn import metrics

# Explicit imports for torch names used below (previously only reachable via
# the star imports above). Placed after them so the original import order and
# precedence of the star imports is preserved.
import torch
import torch.nn.functional as F
from torch import nn

# TODO pr-auc
# TODO f1-score


class CNN(pl.LightningModule):
    """Five-layer 2-D CNN with an independent sigmoid per output unit.

    Each layer is Conv2d(3x3, padding=1) -> BatchNorm2d -> ELU, followed by
    max pooling (except layer 2). The classifier is dropout followed by a
    linear layer and an element-wise sigmoid.

    Loss, validation/test steps, epoch-end aggregation and data loaders are
    delegated to module-level helpers of the same names, which are not defined
    in this file and presumably come from the star imports.

    NOTE(review): hook names such as ``tng_dataloader``, ``validation_end``,
    ``test_end``, the ``@pl.data_loader`` decorator and ``self.experiment``
    appear to target an early pytorch_lightning release; confirm the pinned
    version before upgrading.

    Args:
        num_class: number of output units of the final linear layer.
    """

    def __init__(self, num_class):
        super().__init__()

        # Batch norm over the raw single-channel input.
        self.bn_init = nn.BatchNorm2d(1)

        # layer 1
        self.conv_1 = nn.Conv2d(1, 64, 3, padding=1)
        self.bn_1 = nn.BatchNorm2d(64)
        self.mp_1 = nn.MaxPool2d((2, 4))

        # layer 2
        self.conv_2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn_2 = nn.BatchNorm2d(128)
        # Not applied in forward() (layer 2 has no pooling); it holds no
        # parameters, so keeping it does not affect the state_dict.
        self.mp_2 = nn.MaxPool2d((2, 4))

        # layer 3
        self.conv_3 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn_3 = nn.BatchNorm2d(128)
        self.mp_3 = nn.MaxPool2d((2, 4))

        # layer 4
        self.conv_4 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn_4 = nn.BatchNorm2d(128)
        self.mp_4 = nn.MaxPool2d((3, 5))

        # layer 5
        self.conv_5 = nn.Conv2d(128, 64, 3, padding=1)
        self.bn_5 = nn.BatchNorm2d(64)
        self.mp_5 = nn.MaxPool2d((4, 4))

        # classifier: 320 input features = 64 channels * 5 * 1 spatial
        # positions for an input of shape (B, 1, 256, 512) -- see forward().
        self.dense = nn.Linear(320, num_class)
        self.dropout = nn.Dropout(0.5)

    @staticmethod
    def _conv_bn_elu(x, conv, bn):
        """Apply ``conv`` -> ``bn`` -> ELU (alpha=1.0, same as ``nn.ELU()``)."""
        return F.elu(bn(conv(x)))

    def forward(self, x):
        """Return per-class sigmoid scores for a batch.

        Args:
            x: 4-D tensor whose dim 1 has size 1 (required by ``bn_init`` and
                ``conv_1``). Only the first 512 entries of the last axis are
                used. With the pooling sizes below, an input of shape
                (B, 1, 256, 512) reaches the classifier as (B, 64, 5, 1), i.e.
                the 320 features ``self.dense`` expects; other input sizes
                must also flatten to 320.

        Returns:
            Tensor of shape (B, num_class) with values in (0, 1).
        """
        # Crop the last axis to at most 512 positions.
        x = x[:, :, :, :512]
        x = self.bn_init(x)

        x = self.mp_1(self._conv_bn_elu(x, self.conv_1, self.bn_1))
        # Layer 2 has no pooling (the pooled variant is disabled), so
        # self.mp_2 is unused.
        x = self._conv_bn_elu(x, self.conv_2, self.bn_2)
        x = self.mp_3(self._conv_bn_elu(x, self.conv_3, self.bn_3))
        x = self.mp_4(self._conv_bn_elu(x, self.conv_4, self.bn_4))
        x = self.mp_5(self._conv_bn_elu(x, self.conv_5, self.bn_5))

        # classifier
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        return torch.sigmoid(self.dense(x))

    def my_loss(self, y_hat, y):
        """Delegate to the module-level ``my_loss(y_hat, y)``."""
        return my_loss(y_hat, y)

    def training_step(self, data_batch, batch_nb):
        """Compute the training loss for one batch.

        The batch is unpacked as ``(x, _, y)``; the middle element is ignored.
        Both predictions and targets are cast to float before the loss.
        """
        x, _, y = data_batch
        # NOTE(review): forward_full_song is not defined on this class in this
        # file; it must be provided elsewhere or this raises AttributeError.
        # Confirm where it comes from.
        y_hat = self.forward_full_song(x, y)
        return {'loss': self.my_loss(y_hat.float(), y.float())}

    def validation_step(self, data_batch, batch_nb):
        """Delegate to the module-level ``validation_step(model, batch, batch_nb)``."""
        return validation_step(self, data_batch, batch_nb)

    def test_step(self, data_batch, batch_nb):
        """Delegate to the module-level ``test_step(model, batch, batch_nb)``."""
        return test_step(self, data_batch, batch_nb)

    def test_end(self, outputs):
        """Aggregate test outputs via ``test_end``, log the result, and return it."""
        test_metrics = test_end(outputs)
        self.experiment.log(test_metrics)
        return test_metrics

    def validation_end(self, outputs):
        """Delegate to the module-level ``validation_end(outputs)``."""
        return validation_end(outputs)

    def configure_optimizers(self):
        """Return a single Adam optimizer (lr=1e-4) over all parameters."""
        # Optimizer settings "from their code" (original author's note).
        return [torch.optim.Adam(self.parameters(), lr=1e-4)]

    @pl.data_loader
    def tng_dataloader(self):
        """Return the training loader from the module-level ``tng_dataloader``."""
        return tng_dataloader()

    @pl.data_loader
    def val_dataloader(self):
        """Return the validation loader from the module-level ``val_dataloader``."""
        return val_dataloader()

    @pl.data_loader
    def test_dataloader(self):
        """Return the test loader from the module-level ``test_dataloader``."""
        return test_dataloader()

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return ``parent_parser`` unchanged; this model adds no CLI arguments."""
        return parent_parser