Commit e2682ad2 authored by Verena Praher

Adapt resnet to use BasePtlModel

parent ee66390f
@@ -2,7 +2,7 @@ from utils import CURR_RUN_PATH, USE_GPU, logger, init_experiment
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from test_tube import Experiment, HyperOptArgumentParser
-from models.resnet18 import Network
+from models.resnet import Network
 import os
@@ -33,7 +33,8 @@ def run(hparams):
         trainer = Trainer(gpus=[0], distributed_backend=None,
                           experiment=exp, max_nb_epochs=500, train_percent_check=1.0,
                           fast_dev_run=False, early_stop_callback=early_stop,
-                          checkpoint_callback=checkpoint_callback)
+                          checkpoint_callback=checkpoint_callback,
+                          nb_sanity_val_steps=0)
     else:
         trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
                           fast_dev_run=True)
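
Side note, not part of the commit: in the pytorch-lightning 0.x API used here, the new `nb_sanity_val_steps=0` argument disables the short validation pass that otherwise runs before training starts. The hunk above only shows the `else:` branch's counterpart; reconstructed from the `USE_GPU` import and the dangling `else:`, the surrounding control flow is presumably the following sketch (the `if USE_GPU:` line and indentation are assumptions):

# Assumed surrounding structure; only the else-branch is visible in the hunk.
# USE_GPU comes from utils, as the first hunk's import line shows.
if USE_GPU:
    trainer = Trainer(gpus=[0], distributed_backend=None,
                      experiment=exp, max_nb_epochs=500, train_percent_check=1.0,
                      fast_dev_run=False, early_stop_callback=early_stop,
                      checkpoint_callback=checkpoint_callback,
                      nb_sanity_val_steps=0)  # skip the pre-training sanity val pass
else:
    trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
                      fast_dev_run=True)  # quick CPU smoke test
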
@@ -2,21 +2,18 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import pytorch_lightning as pl
-from models.shared_stuff import tng_dataloader, val_dataloader, test_dataloader, \
-    validation_end, training_step, validation_step, test_step, test_end
+from models.shared_stuff import BasePtlModel, base_model_config
 from test_tube import HyperOptArgumentParser
 from sklearn.metrics import roc_auc_score
 # TODO pr-auc
 # TODO f1-score
 from models.resnet_arch import ResNet, BasicBlock, Bottleneck

-class Network(pl.LightningModule):
+class Network(BasePtlModel):
     def __init__(self, hparams, num_tags):
-        super(Network, self).__init__()
+        super(Network, self).__init__(base_model_config, hparams)
         self.num_tags = num_tags
         self.hparams = hparams
         self.arch = hparams.arch
@@ -35,7 +32,6 @@ class Network(pl.LightningModule):
         self.model = nn.Sequential(
             ResNet(blocktype, layers, num_classes=self.num_tags),
             nn.Sigmoid())
-        # TODO: need to check if optimizer recognizes these parameters
         # num_features = self.model.fc.in_features
         # self.model.fc = nn.Linear(num_features, self.num_tags)  # overwriting fc layer
         # self.sig = nn.Sigmoid()
@@ -51,35 +47,6 @@ class Network(pl.LightningModule):
     def configure_optimizers(self):
         return [torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)]

-    def training_step(self, data_batch, batch_nb):
-        return training_step(self, data_batch, batch_nb)
-
-    def validation_step(self, data_batch, batch_nb):
-        return validation_step(self, data_batch, batch_nb)
-
-    def validation_end(self, outputs):
-        return validation_end(outputs)
-
-    def test_step(self, data_batch, batch_nb):
-        return test_step(self, data_batch, batch_nb)
-
-    def test_end(self, outputs):
-        test_metrics = test_end(outputs)
-        self.experiment.log(test_metrics)
-        return test_metrics
-
-    @pl.data_loader
-    def tng_dataloader(self):
-        return tng_dataloader(self.hparams.batch_size)
-
-    @pl.data_loader
-    def val_dataloader(self):
-        return val_dataloader(self.hparams.batch_size)
-
-    @pl.data_loader
-    def test_dataloader(self):
-        return test_dataloader(self.hparams.batch_size)
-
     @staticmethod
     def add_model_specific_args(parent_parser):
         parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
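
The 29 deleted lines above were one-line wrappers around free functions imported from `models.shared_stuff`; the commit hoists that boilerplate into the shared base class. `BasePtlModel` itself is not shown in this diff, but based on the removed interface it plausibly looks something like the sketch below (everything here is an assumption, written as if it lives in `models/shared_stuff.py` next to the existing free functions):

# Hypothetical sketch of BasePtlModel -- NOT part of this commit's diff.
# It centralizes the step/loader wrappers deleted above so each model
# (resnet, ...) only subclasses it and adds its own architecture.
import pytorch_lightning as pl

class BasePtlModel(pl.LightningModule):
    def __init__(self, config, hparams):
        super(BasePtlModel, self).__init__()
        self.config = config    # e.g. base_model_config with shared defaults
        self.hparams = hparams

    # Names on the right resolve to the module-level functions that
    # models/resnet.py previously imported from models.shared_stuff.
    def training_step(self, data_batch, batch_nb):
        return training_step(self, data_batch, batch_nb)

    def validation_step(self, data_batch, batch_nb):
        return validation_step(self, data_batch, batch_nb)

    def validation_end(self, outputs):
        return validation_end(outputs)

    def test_step(self, data_batch, batch_nb):
        return test_step(self, data_batch, batch_nb)

    def test_end(self, outputs):
        test_metrics = test_end(outputs)
        self.experiment.log(test_metrics)
        return test_metrics

    @pl.data_loader
    def tng_dataloader(self):
        return tng_dataloader(self.hparams.batch_size)

    @pl.data_loader
    def val_dataloader(self):
        return val_dataloader(self.hparams.batch_size)

    @pl.data_loader
    def test_dataloader(self):
        return test_dataloader(self.hparams.batch_size)
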
@@ -89,4 +56,6 @@ class Network(pl.LightningModule):
         parser.opt_list('--learning_rate', default=0.0001, type=float,
                         options=[0.0001, 0.0005, 0.001],
                         tunable=True)
+        parser.opt_list('--slicing_mode', default='slice', options=['full', 'slice'], type=str, tunable=True)
+        parser.opt_list('--input_size', default=512, options=[512, 1024], type=int, tunable=False)
         return parser
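
For context, a hedged sketch of how these test_tube options are typically consumed; the strategy value, trial count, and the call into the experiment's `run()` are assumptions, not code from this repository:

# Hypothetical driver -- not part of the commit. Shows how opt_list values
# get swept: tunable=True options (learning_rate, slicing_mode) vary across
# trials, while tunable=False ones (input_size) stay at their default.
from test_tube import HyperOptArgumentParser
from models.resnet import Network

parent_parser = HyperOptArgumentParser(strategy='random_search', add_help=False)
parser = Network.add_model_specific_args(parent_parser)
hparams = parser.parse_args()

for trial in hparams.trials(5):   # sample 5 hyperparameter configurations
    run(trial)                    # run() as defined in the experiment script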