from utils import *
import pytorch_lightning as pl
from models.shared_stuff import *
from sklearn import metrics


class CNN(pl.LightningModule):
    """Five-layer 2-D CNN multi-label tagger built on PyTorch Lightning.

    Each conv layer is Conv2d(3x3, padding=1) -> BatchNorm2d -> ELU, followed
    (except layer 2) by a max-pool. The flattened features pass through
    dropout and a linear layer with a sigmoid, so ``forward`` returns
    per-class probabilities in [0, 1] rather than raw logits.

    The loss, data loaders and validation/test aggregation are delegated to
    module-level helpers brought in by the star imports above.
    """

    def __init__(self, num_class):
        """Build the network.

        Args:
            num_class: number of output classes (width of the final linear
                layer).
        """
        super(CNN, self).__init__()

        # normalisation of the single-channel input
        self.bn_init = nn.BatchNorm2d(1)

        # layer 1
        self.conv_1 = nn.Conv2d(1, 64, 3, padding=1)
        self.bn_1 = nn.BatchNorm2d(64)
        self.mp_1 = nn.MaxPool2d((2, 4))

        # layer 2 -- mp_2 is constructed but NOT applied in forward(); the
        # dense input size below depends on that, so keep the two in sync.
        self.conv_2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn_2 = nn.BatchNorm2d(128)
        self.mp_2 = nn.MaxPool2d((2, 4))

        # layer 3
        self.conv_3 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn_3 = nn.BatchNorm2d(128)
        self.mp_3 = nn.MaxPool2d((2, 4))

        # layer 4
        self.conv_4 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn_4 = nn.BatchNorm2d(128)
        self.mp_4 = nn.MaxPool2d((3, 5))

        # layer 5
        self.conv_5 = nn.Conv2d(128, 64, 3, padding=1)
        self.bn_5 = nn.BatchNorm2d(64)
        self.mp_5 = nn.MaxPool2d((4, 4))

        # classifier: 320 = 64 channels x remaining spatial positions.
        # NOTE(review): with the pools applied in forward() this matches e.g.
        # an input of height 256 and width 512 (-> 5 x 1); confirm vs. data.
        self.dense = nn.Linear(320, num_class)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return per-class sigmoid probabilities for a batch.

        Args:
            x: 4-D single-channel input batch (``bn_init`` is BatchNorm2d(1)).

        Returns:
            Tensor of shape (batch, num_class) with values in [0, 1].
        """
        # functional ELU (default alpha=1.0, same as nn.ELU()) avoids
        # constructing a new module on every call
        elu = nn.functional.elu

        x = self.bn_init(x)
        x = self.mp_1(elu(self.bn_1(self.conv_1(x))))
        # layer 2 intentionally has no pooling (see __init__)
        x = elu(self.bn_2(self.conv_2(x)))
        x = self.mp_3(elu(self.bn_3(self.conv_3(x))))
        x = self.mp_4(elu(self.bn_4(self.conv_4(x))))
        x = self.mp_5(elu(self.bn_5(self.conv_5(x))))

        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        return torch.sigmoid(self.dense(x))

    def my_loss(self, y_hat, y):
        """Delegate to the module-level ``my_loss`` helper."""
        return my_loss(y_hat, y)

    def forward_full_song(self, x, y):
        """Run the model on a (possibly longer) input.

        Only the first 512 positions of the last axis are used; ``y`` is
        accepted for interface compatibility but ignored.

        TODO: true full-song inference (averaging predictions over 512-wide
        windows with a hop of 256) was sketched previously but is not
        implemented.
        """
        return self.forward(x[:, :, :, :512])

    def training_step(self, data_batch, batch_nb):
        """Compute the training loss for one ``(x, _, y)`` batch."""
        x, _, y = data_batch
        y_hat = self.forward_full_song(x, y)
        y = y.float()
        y_hat = y_hat.float()
        return {'loss': self.my_loss(y_hat, y)}

    def validation_step(self, data_batch, batch_nb):
        """Compute loss, ROC-AUC, PR-AUC and micro F1 for one batch.

        Returns:
            dict with keys 'val_loss', 'val_rocauc', 'val_prauc',
            'val_fscore'; aggregation happens in ``validation_end``.
        """
        x, _, y = data_batch
        y_hat = self.forward_full_song(x, y)
        y = y.float()
        y_hat = y_hat.float()

        # sklearn needs CPU data without autograd history
        y_true = y.detach().t().cpu()
        y_score = y_hat.detach().t().cpu()
        decisions = (y_score > 0.5).type(torch.float)

        # NOTE(review): the transpose makes sklearn treat each batch item as a
        # label column, so the (macro-averaged) ROC/PR-AUC are averaged per
        # sample rather than per class -- confirm this is intended.
        rocauc = metrics.roc_auc_score(y_true, y_score)
        prauc = metrics.average_precision_score(y_true, y_score)
        _, _, fscore, _ = metrics.precision_recall_fscore_support(
            y_true, decisions, average='micro')

        return {'val_loss': self.my_loss(y_hat, y),
                'val_rocauc': rocauc,
                'val_prauc': prauc,
                'val_fscore': fscore}

    def validation_end(self, outputs):
        """Aggregate validation outputs via the module-level helper."""
        return validation_end(outputs)

    def test_step(self, data_batch, batch_nb):
        """Delegate a test batch to the module-level ``test_step`` helper."""
        return test_step(self, data_batch, batch_nb)

    def test_end(self, outputs):
        """Aggregate test outputs via the module-level helper."""
        return test_end(outputs)

    def configure_optimizers(self):
        """Adam with lr=1e-4."""
        return [torch.optim.Adam(self.parameters(), lr=1e-4)]  # from their code

    @pl.data_loader
    def tng_dataloader(self):
        return tng_dataloader()

    @pl.data_loader
    def val_dataloader(self):
        return val_dataloader()

    @pl.data_loader
    def test_dataloader(self):
        return test_dataloader()

    @staticmethod
    def add_model_specific_args(parent_parser):
        """No model-specific CLI arguments; return the parser unchanged."""
        return parent_parser