Commit 8e66a220 authored by Verena Praher

merge

parents a54119c7 78ef5ebb
-from utils import CURR_RUN_PATH, USE_GPU, logger
+from utils import USE_GPU, init_experiment
 from pytorch_lightning import Trainer
-from test_tube import Experiment
+from test_tube import Experiment, HyperOptArgumentParser
 from models.baseline import CNN as Network
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 import os

-def run():
+config = {
+    'epochs': 1
+}
+
+def epochs_500():
+    global config
+    config['epochs'] = 500
+
+def run(hparams):
+    init_experiment(comment=hparams.experiment_name)
+    from utils import CURR_RUN_PATH, logger  # import these after init_experiment
     logger.info(CURR_RUN_PATH)
+    logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
     exp = Experiment(save_dir=CURR_RUN_PATH)
+
+    def setup_config():
+        def print_config():
+            global config
+            st = '---------CONFIG--------\n'
+            for k in config.keys():
+                st += k + ':' + str(config.get(k)) + '\n'
+            return st
+
+        conf = hparams.config
+        if conf is not None:
+            conf_func = globals()[conf]
+            try:
+                conf_func()
+            except:
+                logger.error(f"Config {conf} not defined")
+        logger.info(print_config())
+
+    setup_config()
+    global config
     # parameters used in the baseline (read from main.py and solver.py)
     # n_epochs = 500
     # lr = 1e-4
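An aside on the config mechanism above: `globals()[conf]` is evaluated before the `try`, so a mistyped `--config` name raises an uncaught KeyError instead of reaching the `logger.error` branch. A minimal standalone sketch of the same name-based dispatch (the `apply_config` helper is hypothetical, not part of the commit) that keeps the lookup inside the handler:

# Sketch: resolve a config name via globals() and call it to mutate `config`.
config = {'epochs': 1}

def epochs_500():
    global config
    config['epochs'] = 500

def apply_config(name):
    # Lookup happens inside the try, so unknown names are reported cleanly.
    try:
        globals()[name]()
    except (KeyError, TypeError):
        print(f"Config {name} not defined")

apply_config('epochs_500')
print(config)  # {'epochs': 500}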
@@ -32,13 +64,13 @@ def run():
     if USE_GPU:
         trainer = Trainer(gpus=[0], distributed_backend='ddp',
-                          experiment=exp, max_nb_epochs=100, train_percent_check=1.0,
+                          experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
                           fast_dev_run=False,
                           early_stop_callback=early_stop,
                           checkpoint_callback=checkpoint_callback
                           )
     else:
-        trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
+        trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
                           fast_dev_run=True)

     model = Network(num_class=56)  # TODO num_class
@@ -50,4 +82,14 @@ def run():
 if __name__ == '__main__':
-    run()
\ No newline at end of file
+    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
+    parent_parser.add_argument('--experiment_name', type=str,
+                               default='pt_lightning_exp_a', help='test tube exp name')
+    parent_parser.add_argument('--config', type=str, help='config function to run')
+    # TODO: multiple arguments for --config using nargs='+' is not working with the
+    # test_tube implementation of argument parser
+    parent_parser.add_argument('--train_percent', type=float,
+                               default=1.0, help='how much train data to use')
+    parser = Network.add_model_specific_args(parent_parser)
+    hyperparams = parser.parse_args()
+    run(hyperparams)
\ No newline at end of file
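For reference, a sketch of how the new entry point is driven (the script name and flag values below are illustrative, not from the commit). HyperOptArgumentParser subclasses argparse.ArgumentParser, so parse_args() also accepts an explicit argv list for in-process testing:

# Hypothetical invocation:
#   python experiment.py --experiment_name baseline --config epochs_500 --train_percent 0.5
from test_tube import HyperOptArgumentParser

parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parser.add_argument('--experiment_name', type=str, default='pt_lightning_exp_a')
parser.add_argument('--config', type=str)
parser.add_argument('--train_percent', type=float, default=1.0)

hparams = parser.parse_args(['--config', 'epochs_500', '--train_percent', '0.5'])
print(hparams.experiment_name, hparams.config, hparams.train_percent)
# pt_lightning_exp_a epochs_500 0.5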
@@ -113,10 +113,11 @@ class CNN(pl.LightningModule):
         y_hat = self.forward_full_song(x, y)
         y = y.float()
         y_hat = y_hat.float()
+        decisions = y_hat.t().cpu() > 0.5
+        decisions = decisions.type(torch.float)
         rocauc = metrics.roc_auc_score(y.t().cpu(), y_hat.t().cpu())
         prauc = metrics.average_precision_score(y.t().cpu(), y_hat.t().cpu())
-        # _, _, fscore, _ = metrics.precision_recall_fscore_support(y.t().cpu(), y_hat.t().cpu())
-        fscore = 0.
+        _, _, fscore, _ = metrics.precision_recall_fscore_support(y.t().cpu(), decisions, average='micro')
         return {'val_loss': self.my_loss(y_hat, y),
                 'val_rocauc': rocauc,
                 'val_prauc': prauc,
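The new `decisions` tensor exists because `precision_recall_fscore_support` needs hard 0/1 predictions, while `roc_auc_score` and `average_precision_score` consume raw scores. A self-contained illustration with synthetic multi-label data (not from the repository):

import numpy as np
from sklearn import metrics

y_true = np.array([[1, 0, 1],
                   [0, 1, 0]], dtype=float)          # multi-label ground truth
y_scores = np.array([[0.9, 0.2, 0.4],
                     [0.3, 0.8, 0.7]])               # sigmoid-style outputs

rocauc = metrics.roc_auc_score(y_true, y_scores)     # AUC takes scores as-is
decisions = (y_scores > 0.5).astype(float)           # F-score needs hard decisions
_, _, fscore, _ = metrics.precision_recall_fscore_support(
    y_true, decisions, average='micro')
print(rocauc, fscore)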
@@ -134,6 +135,12 @@ class CNN(pl.LightningModule):
     def validation_end(self, outputs):
         return validation_end(outputs)

+    def test_step(self, data_batch, batch_nb):
+        return test_step(self, data_batch, batch_nb)
+
+    def test_end(self, outputs):
+        return test_end(outputs)
+
     def configure_optimizers(self):
         return [torch.optim.Adam(self.parameters(), lr=1e-4)]  # from their code
@@ -150,6 +157,6 @@ class CNN(pl.LightningModule):
         return test_dataloader()

     @staticmethod
-    def add_model_specific_args(parent_parser, root_dir):
+    def add_model_specific_args(parent_parser):
         return parent_parser
         pass
\ No newline at end of file
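Dropping `root_dir` aligns the signature with the single-argument call `Network.add_model_specific_args(parent_parser)` in the training script. For orientation, a sketch of what this hook is for (the `--learning_rate` flag is hypothetical; the committed version returns the parser unchanged):

from test_tube import HyperOptArgumentParser

class CNN:
    @staticmethod
    def add_model_specific_args(parent_parser):
        # A fleshed-out version would register model hyperparameters here.
        parent_parser.add_argument('--learning_rate', type=float, default=1e-4)
        return parent_parser

parser = CNN.add_model_specific_args(
    HyperOptArgumentParser(strategy='grid_search', add_help=False))
print(parser.parse_args([]).learning_rate)  # 0.0001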
@@ -53,7 +53,8 @@ elif hostname == 'shreyan-HP':  # Laptop Shreyan
     USE_GPU = False
 else:
     PATH_DATA_ROOT = '/mnt/2tb/datasets/MTG-Jamendo'
-    PATH_DATA_CACHE = os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms')
+    # PATH_DATA_CACHE = os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms')
+    PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets'
     USE_GPU = False

 if username == 'verena':
@@ -69,43 +70,52 @@ TRAINED_MODELS_PATH = ''
 # run name
 def make_run_name(suffix=''):
-    assert ' ' not in suffix
+    # assert ' ' not in suffix
     hash = hashlib.sha1()
     hash.update(str(time.time()).encode('utf-8'))
     run_hash = hash.hexdigest()[:5]
-    name = run_hash + suffix
+    name = run_hash + f' - {suffix}'
     return name

-curr_run_name = make_run_name()
-CURR_RUN_PATH = os.path.join(PATH_RESULTS, 'runs', curr_run_name)
-if not os.path.isdir(CURR_RUN_PATH):
-    os.mkdir(CURR_RUN_PATH)
+CURR_RUN_PATH = ''

-# SET UP LOGGING =============================================
 filelog = logging.getLogger()
 streamlog = logging.getLogger()
 logger = logging.getLogger()

-fh = logging.FileHandler(os.path.join(CURR_RUN_PATH, f'{curr_run_name}.log'))
-sh = logging.StreamHandler()
-formatter = logging.Formatter('%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s')
-fh.setFormatter(formatter)
-sh.setFormatter(formatter)
-
-# filelog logs only to file
-filelog.addHandler(fh)
-filelog.setLevel(logging.INFO)
-
-# streamlog logs only to terminal
-streamlog.addHandler(sh)
-streamlog.setLevel(logging.INFO)
-
-# logger logs to both file and terminal
-logger.addHandler(fh)
-logger.addHandler(sh)
-logger.setLevel(logging.DEBUG)
-# ============================================
+
+def init_experiment(comment='', name=None):
+    global CURR_RUN_PATH
+    global filelog, streamlog, logger
+    if name is None:
+        curr_run_name = make_run_name(comment)
+    else:
+        curr_run_name = name
+    CURR_RUN_PATH = os.path.join(PATH_RESULTS, 'runs', curr_run_name)
+    if not os.path.isdir(CURR_RUN_PATH):
+        os.mkdir(CURR_RUN_PATH)
+
+    # SET UP LOGGING =============================================
+    fh = logging.FileHandler(os.path.join(CURR_RUN_PATH, f'{curr_run_name}.log'))
+    sh = logging.StreamHandler()
+    formatter = logging.Formatter('%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s')
+    fh.setFormatter(formatter)
+    sh.setFormatter(formatter)
+
+    # filelog logs only to file
+    filelog.addHandler(fh)
+    filelog.setLevel(logging.INFO)
+
+    # streamlog logs only to terminal
+    streamlog.addHandler(sh)
+    streamlog.setLevel(logging.INFO)
+
+    # logger logs to both file and terminal
+    logger.addHandler(fh)
+    logger.addHandler(sh)
+    logger.setLevel(logging.DEBUG)
+    # ============================================
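This refactor makes `CURR_RUN_PATH` a module-level name that `init_experiment` rebinds, which is why the training script imports it only after calling `init_experiment`: `from utils import NAME` copies the binding at import time, so an early import would keep the stale empty string. (Worth noting, though unchanged by this commit: `logging.getLogger()` with no name always returns the root logger, so `filelog`, `streamlog`, and `logger` are the same object and the file/terminal split described in the comments does not actually hold.) A minimal reproduction of the binding behavior, using a hypothetical `demo_utils` module as a stand-in for utils.py:

# demo_utils.py -- hypothetical stand-in for utils.py:
#     CURR_RUN_PATH = ''
#     def init_experiment(comment=''):
#         global CURR_RUN_PATH
#         CURR_RUN_PATH = '/tmp/runs/' + comment

import demo_utils
from demo_utils import CURR_RUN_PATH as early_copy  # binding taken too early

demo_utils.init_experiment(comment='test')
from demo_utils import CURR_RUN_PATH as late_copy   # binding taken after init

print(early_copy)  # '' -- stale value from before init_experiment
print(late_copy)   # '/tmp/runs/test'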
 def write_to_file(data, path):
     # not fully implemented. unused function as of now.
...