Commit eb2252a1 authored by Shreyan Chowdhury's avatar Shreyan Chowdhury
Browse files

Merge branch 'master' of gitlab.cp.jku.at:shreyan/moodwalk

parents 4881ccd5 af3e8a7d
...@@ -13,6 +13,7 @@ config = { ...@@ -13,6 +13,7 @@ config = {
'earlystopping_mode': 'min' # 'max' 'earlystopping_mode': 'min' # 'max'
} }
def epochs_500(): def epochs_500():
global config global config
config['epochs'] = 500 config['epochs'] = 500
......
from utils import CURR_RUN_PATH, USE_GPU, logger from utils import CURR_RUN_PATH, USE_GPU, logger, init_experiment
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from test_tube import Experiment from test_tube import Experiment, HyperOptArgumentParser
from models.resnet18 import Network from models.resnet18 import Network
import os import os
def run(): def run(hparams):
# init_experiment(comment=hparams.experiment_name)
init_experiment()
from utils import CURR_RUN_PATH, logger # import these after init_experiment
logger.info(CURR_RUN_PATH) logger.info(CURR_RUN_PATH)
logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
exp = Experiment(save_dir=CURR_RUN_PATH) exp = Experiment(save_dir=CURR_RUN_PATH)
# TODO other training parameters?
# callbacks # callbacks
early_stop = EarlyStopping( early_stop = EarlyStopping(
monitor='val_loss', monitor='val_loss',
...@@ -29,7 +31,7 @@ def run(): ...@@ -29,7 +31,7 @@ def run():
) )
if USE_GPU: if USE_GPU:
trainer = Trainer(gpus=[0], distributed_backend='ddp', trainer = Trainer(gpus=[0], distributed_backend=None,
experiment=exp, max_nb_epochs=100, train_percent_check=1.0, experiment=exp, max_nb_epochs=100, train_percent_check=1.0,
fast_dev_run=False, early_stop_callback=early_stop, fast_dev_run=False, early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback) checkpoint_callback=checkpoint_callback)
...@@ -37,7 +39,7 @@ def run(): ...@@ -37,7 +39,7 @@ def run():
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1, trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True) fast_dev_run=True)
model = Network(56) # TODO num_tags model = Network(hparams, 56) # TODO num_tags
print(model) print(model)
...@@ -46,4 +48,13 @@ def run(): ...@@ -46,4 +48,13 @@ def run():
if __name__=='__main__': if __name__=='__main__':
run() parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
\ No newline at end of file #parent_parser.add_argument('--gpus', type=list, default=[0,1],
# help='how many gpus to use in the node.'
# ' value -1 uses all the gpus on the node')
parser = Network.add_model_specific_args(parent_parser)
hyperparams = parser.parse_args()
# run(hyperparams)
#gpus = ['cuda:0', 'cuda:1']
#hyperparams.optimize_parallel_gpu(run, gpus, 5)
run(hyperparams)
...@@ -4,21 +4,36 @@ import torch.nn.functional as F ...@@ -4,21 +4,36 @@ import torch.nn.functional as F
import pytorch_lightning as pl import pytorch_lightning as pl
from models.shared_stuff import tng_dataloader, val_dataloader, test_dataloader, \ from models.shared_stuff import tng_dataloader, val_dataloader, test_dataloader, \
validation_end, training_step, validation_step, test_step, test_end validation_end, training_step, validation_step, test_step, test_end
from test_tube import HyperOptArgumentParser
from sklearn.metrics import roc_auc_score from sklearn.metrics import roc_auc_score
# TODO pr-auc # TODO pr-auc
# TODO f1-score # TODO f1-score
from models.resnet_arch import ResNet, BasicBlock from models.resnet_arch import ResNet, BasicBlock, Bottleneck
class Network(pl.LightningModule): class Network(pl.LightningModule):
def __init__(self, num_tags): def __init__(self, hparams, num_tags):
super(Network, self).__init__() super(Network, self).__init__()
self.num_tags = num_tags self.num_tags = num_tags
self.hparams = hparams
self.arch = hparams.arch
layers = [2, 2, 2, 2]
blocktype = BasicBlock
if self.arch == 'resnet34':
layers = [3, 4, 6, 3]
elif self.arch == 'resnet50':
layers = [3, 4, 6, 3]
blocktype = Bottleneck
elif self.arch == 'resnet101':
layers = [3, 4, 23, 3]
blocktype = Bottleneck
self.model = nn.Sequential( self.model = nn.Sequential(
ResNet(BasicBlock, [2, 2, 2, 2], num_classes=self.num_tags), ResNet(blocktype, layers, num_classes=self.num_tags),
nn.Sigmoid()) nn.Sigmoid())
# TODO: need to check if optimizer recognizes these parameters # TODO: need to check if optimizer recognizes these parameters
# num_features = self.model.fc.in_features # num_features = self.model.fc.in_features
...@@ -34,7 +49,7 @@ class Network(pl.LightningModule): ...@@ -34,7 +49,7 @@ class Network(pl.LightningModule):
return F.binary_cross_entropy(y_hat, y) return F.binary_cross_entropy(y_hat, y)
def configure_optimizers(self): def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=0.001)] return [torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)]
def training_step(self, data_batch, batch_nb): def training_step(self, data_batch, batch_nb):
return training_step(self, data_batch, batch_nb) return training_step(self, data_batch, batch_nb)
...@@ -55,12 +70,23 @@ class Network(pl.LightningModule): ...@@ -55,12 +70,23 @@ class Network(pl.LightningModule):
@pl.data_loader @pl.data_loader
def tng_dataloader(self): def tng_dataloader(self):
return tng_dataloader() return tng_dataloader(self.hparams.batch_size)
@pl.data_loader @pl.data_loader
def val_dataloader(self): def val_dataloader(self):
return val_dataloader() return val_dataloader(self.hparams.batch_size)
@pl.data_loader @pl.data_loader
def test_dataloader(self): def test_dataloader(self):
return test_dataloader() return test_dataloader(self.hparams.batch_size)
@staticmethod
def add_model_specific_args(parent_parser):
parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
parser.add_argument('--batch_size', default=32, type=int)
parser.opt_list('--arch', default='resnet18', type=str, tunable=True,
options=['resnet18', 'resnet34', 'resnet50', 'resnet101'])
parser.opt_list('--learning_rate', default=0.0001, type=float,
options=[0.0001, 0.0005, 0.001],
tunable=True)
return parser
...@@ -117,28 +117,28 @@ def test_end(outputs): ...@@ -117,28 +117,28 @@ def test_end(outputs):
'test_fscore': fscore} 'test_fscore': fscore}
def tng_dataloader(): def tng_dataloader(batch_size=32):
train_csv = os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv') train_csv = os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv')
cache_x_name = "_ap_mtgjamendo44k" cache_x_name = "_ap_mtgjamendo44k"
dataset = df_get_mtg_set('mtgjamendo', train_csv, PATH_AUDIO, cache_x_name) dataset = df_get_mtg_set('mtgjamendo', train_csv, PATH_AUDIO, cache_x_name)
return DataLoader(dataset=dataset, return DataLoader(dataset=dataset,
batch_size=32, batch_size=batch_size,
shuffle=True) shuffle=True)
def val_dataloader(): def val_dataloader(batch_size=32):
validation_csv = os.path.join(PATH_ANNOTATIONS, 'validation_processed.tsv') validation_csv = os.path.join(PATH_ANNOTATIONS, 'validation_processed.tsv')
cache_x_name = "_ap_mtgjamendo44k" cache_x_name = "_ap_mtgjamendo44k"
dataset = df_get_mtg_set('mtgjamendo_val', validation_csv, PATH_AUDIO, cache_x_name) dataset = df_get_mtg_set('mtgjamendo_val', validation_csv, PATH_AUDIO, cache_x_name)
return DataLoader(dataset=dataset, return DataLoader(dataset=dataset,
batch_size=32, batch_size=batch_size,
shuffle=True) shuffle=True)
def test_dataloader(): def test_dataloader(batch_size=32):
test_csv = os.path.join(PATH_ANNOTATIONS, 'test_processed.tsv') test_csv = os.path.join(PATH_ANNOTATIONS, 'test_processed.tsv')
cache_x_name = "_ap_mtgjamendo44k" cache_x_name = "_ap_mtgjamendo44k"
dataset = df_get_mtg_set('mtgjamendo_test', test_csv, PATH_AUDIO, cache_x_name) dataset = df_get_mtg_set('mtgjamendo_test', test_csv, PATH_AUDIO, cache_x_name)
return DataLoader(dataset=dataset, return DataLoader(dataset=dataset,
batch_size=32, batch_size=batch_size,
shuffle=True) shuffle=True)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment