Commit 8e66a220 authored by Verena Praher's avatar Verena Praher
Browse files

merge

parents a54119c7 78ef5ebb
from utils import CURR_RUN_PATH, USE_GPU, logger
from utils import USE_GPU, init_experiment
from pytorch_lightning import Trainer
from test_tube import Experiment
from test_tube import Experiment, HyperOptArgumentParser
from models.baseline import CNN as Network
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import os
def run():
config = {
'epochs':1
}
def epochs_500():
global config
config['epochs'] = 500
def run(hparams):
init_experiment(comment=hparams.experiment_name)
from utils import CURR_RUN_PATH, logger # import these after init_experiment
logger.info(CURR_RUN_PATH)
logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
exp = Experiment(save_dir=CURR_RUN_PATH)
def setup_config():
def print_config():
global config
st = '---------CONFIG--------\n'
for k in config.keys():
st += k+':'+str(config.get(k))+'\n'
return st
conf = hparams.config
if conf is not None:
conf_func = globals()[conf]
try:
conf_func()
except:
logger.error(f"Config {conf} not defined")
logger.info(print_config())
setup_config()
global config
# parameters used in the baseline (read from main.py and solver.py)
# n_epochs = 500
# lr = 1e-4
......@@ -32,13 +64,13 @@ def run():
if USE_GPU:
trainer = Trainer(gpus=[0], distributed_backend='ddp',
experiment=exp, max_nb_epochs=100, train_percent_check=1.0,
experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
fast_dev_run=False,
early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback
)
)
else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
fast_dev_run=True)
model = Network(num_class=56) # TODO num_class
......@@ -50,4 +82,14 @@ def run():
if __name__=='__main__':
run()
\ No newline at end of file
parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parent_parser.add_argument('--experiment_name', type=str,
default='pt_lightning_exp_a', help='test tube exp name')
parent_parser.add_argument('--config', type=str, help='config function to run')
#TODO : multiple arguments for --config using nargs='+' is not working with the test_tube
# implementation of argument parser
parent_parser.add_argument('--train_percent', type=float,
default=1.0, help='how much train data to use')
parser = Network.add_model_specific_args(parent_parser)
hyperparams = parser.parse_args()
run(hyperparams)
\ No newline at end of file
......@@ -113,10 +113,11 @@ class CNN(pl.LightningModule):
y_hat = self.forward_full_song(x, y)
y = y.float()
y_hat = y_hat.float()
decisions = y_hat.t().cpu() > 0.5
decisions = decisions.type(torch.float)
rocauc = metrics.roc_auc_score(y.t().cpu(), y_hat.t().cpu())
prauc = metrics.average_precision_score(y.t().cpu(), y_hat.t().cpu())
# _, _, fscore, _ = metrics.precision_recall_fscore_support(y.t().cpu(), y_hat.t().cpu())
fscore = 0.
_, _, fscore, _ = metrics.precision_recall_fscore_support(y.t().cpu(), decisions, average='micro')
return {'val_loss': self.my_loss(y_hat, y),
'val_rocauc':rocauc,
'val_prauc':prauc,
......@@ -134,6 +135,12 @@ class CNN(pl.LightningModule):
def validation_end(self, outputs):
return validation_end(outputs)
def test_step(self, data_batch, batch_nb):
return test_step(self, data_batch, batch_nb)
def test_end(self, outputs):
return test_end(outputs)
def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=1e-4)] # from their code
......@@ -150,6 +157,6 @@ class CNN(pl.LightningModule):
return test_dataloader()
@staticmethod
def add_model_specific_args(parent_parser, root_dir):
def add_model_specific_args(parent_parser):
return parent_parser
pass
\ No newline at end of file
......@@ -53,7 +53,8 @@ elif hostname == 'shreyan-HP': # Laptop Shreyan
USE_GPU = False
else:
PATH_DATA_ROOT = '/mnt/2tb/datasets/MTG-Jamendo'
PATH_DATA_CACHE = os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms')
# PATH_DATA_CACHE = os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms')
PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets'
USE_GPU = False
if username == 'verena':
......@@ -69,43 +70,52 @@ TRAINED_MODELS_PATH = ''
# run name
def make_run_name(suffix=''):
assert ' ' not in suffix
# assert ' ' not in suffix
hash = hashlib.sha1()
hash.update(str(time.time()).encode('utf-8'))
run_hash = hash.hexdigest()[:5]
name = run_hash + suffix
name = run_hash + f' - {suffix}'
return name
curr_run_name = make_run_name()
CURR_RUN_PATH = os.path.join(PATH_RESULTS, 'runs', curr_run_name)
if not os.path.isdir(CURR_RUN_PATH):
os.mkdir(CURR_RUN_PATH)
# SET UP LOGGING =============================================
CURR_RUN_PATH = ''
filelog = logging.getLogger()
streamlog = logging.getLogger()
logger = logging.getLogger()
fh = logging.FileHandler(os.path.join(CURR_RUN_PATH, f'{curr_run_name}.log'))
sh = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s')
fh.setFormatter(formatter)
sh.setFormatter(formatter)
# filelog logs only to file
filelog.addHandler(fh)
filelog.setLevel(logging.INFO)
# streamlog logs only to terminal
streamlog.addHandler(sh)
streamlog.setLevel(logging.INFO)
# logger logs to both file and terminal
logger.addHandler(fh)
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)
# ============================================
def init_experiment(comment='', name=None):
global CURR_RUN_PATH
global filelog, streamlog, logger
if name is None:
curr_run_name = make_run_name(comment)
else:
curr_run_name = name
CURR_RUN_PATH = os.path.join(PATH_RESULTS, 'runs', curr_run_name)
if not os.path.isdir(CURR_RUN_PATH):
os.mkdir(CURR_RUN_PATH)
# SET UP LOGGING =============================================
fh = logging.FileHandler(os.path.join(CURR_RUN_PATH, f'{curr_run_name}.log'))
sh = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s')
fh.setFormatter(formatter)
sh.setFormatter(formatter)
# filelog logs only to file
filelog.addHandler(fh)
filelog.setLevel(logging.INFO)
# streamlog logs only to terminal
streamlog.addHandler(sh)
streamlog.setLevel(logging.INFO)
# logger logs to both file and terminal
logger.addHandler(fh)
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)
# ============================================
def write_to_file(data, path):
# not fully implemented. unused function as of now.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment