Commit cc057fed authored by Shreyan Chowdhury

add experiment_name functionality, fscore with 0.5 threshold

parent 22b36bc3
-from utils import CURR_RUN_PATH, USE_GPU, logger
+from utils import USE_GPU, init_experiment
 from pytorch_lightning import Trainer
-from test_tube import Experiment
+from test_tube import Experiment, HyperOptArgumentParser
 from models.baseline import CNN as Network

-def run():
+def run(hparams):
+    init_experiment(comment=hparams.experiment_name)
+    from utils import CURR_RUN_PATH, logger  # import these after init_experiment
     logger.info(CURR_RUN_PATH)
     exp = Experiment(save_dir=CURR_RUN_PATH)
@@ -16,10 +20,10 @@ def run():
     if USE_GPU:
         trainer = Trainer(gpus=[0], distributed_backend='ddp',
-                          experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
+                          experiment=exp, max_nb_epochs=1, train_percent_check=1.0,
                           fast_dev_run=False)
     else:
-        trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
+        trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
                           fast_dev_run=True)

     model = Network(num_class=56)  # TODO num_class
@@ -31,4 +35,9 @@ def run():
 if __name__=='__main__':
-    run()
\ No newline at end of file
+    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
+    parent_parser.add_argument('--experiment_name', type=str,
+                               default='pt_lightning_exp_a', help='test tube exp name')
+    parser = Network.add_model_specific_args(parent_parser)
+    hyperparams = parser.parse_args()
+    run(hyperparams)
\ No newline at end of file
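A note on the deferred `from utils import CURR_RUN_PATH, logger` inside run(): `from module import name` copies the binding at import time, so importing CURR_RUN_PATH at the top of the script would capture the empty placeholder before init_experiment() rebinds it. The following standalone sketch illustrates the pitfall the in-function import avoids; the toy `config` module and `setup()` function are stand-ins for `utils` and `init_experiment`, not part of the repository.

import sys
import types

# Toy stand-in for the repository's utils module (names are illustrative only).
config = types.ModuleType('config')
config.CURR_RUN_PATH = ''                      # module-level placeholder, like utils.CURR_RUN_PATH

def _setup():
    # Rebinds the module-level name, as init_experiment() does for CURR_RUN_PATH.
    config.CURR_RUN_PATH = '/results/runs/abc12 - demo'

config.setup = _setup
sys.modules['config'] = config

from config import CURR_RUN_PATH               # imported too early: copies the '' binding
config.setup()
print(CURR_RUN_PATH)                           # -> '' (stale value)

from config import CURR_RUN_PATH               # re-imported after setup(): current binding
print(CURR_RUN_PATH)                           # -> '/results/runs/abc12 - demo'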
@@ -119,22 +119,48 @@ class CNN(pl.LightningModule):
         y_hat = self.forward_full_song(x, y)
         y = y.float()
         y_hat = y_hat.float()
+        decisions = y_hat.t().cpu() > 0.5
+        decisions = decisions.type(torch.float)
         rocauc = metrics.roc_auc_score(y.t().cpu(), y_hat.t().cpu())
         prauc = metrics.average_precision_score(y.t().cpu(), y_hat.t().cpu())
-        # _, _, fscore, _ = metrics.precision_recall_fscore_support(y.t().cpu(), y_hat.t().cpu())
-        fscore = 0.
+        _, _, fscore, _ = metrics.precision_recall_fscore_support(y.t().cpu(), decisions, average='micro')
         return {'val_loss': self.my_loss(y_hat, y),
                 'rocauc':rocauc,
                 'prauc':prauc,
                 'fscore':fscore}

     def validation_end(self, outputs):
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        # avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
         avg_auc = torch.stack([torch.tensor([x['rocauc']]) for x in outputs]).mean()
         avg_prauc = torch.stack([torch.tensor([x['prauc']]) for x in outputs]).mean()
         avg_fscore = torch.stack([torch.tensor([x['fscore']]) for x in outputs]).mean()
-        return {'val_loss':avg_loss,
-                'rocauc':avg_auc,
+        return {'rocauc':avg_auc,
+                'prauc':avg_prauc,
+                'fscore':avg_fscore}
+
+    def test_step(self, data_batch, batch_nb):
+        # print("data_batch", data_batch)
+        x, _, y = data_batch
+        # print("x", x)
+        # print("y", y)
+        y_hat = self.forward_full_song(x, y)
+        y = y.float()
+        y_hat = y_hat.float()
+        decisions = y_hat.t().cpu() > 0.5
+        decisions = decisions.type(torch.float)
+        rocauc = metrics.roc_auc_score(y.t().cpu(), y_hat.t().cpu())
+        prauc = metrics.average_precision_score(y.t().cpu(), y_hat.t().cpu())
+        _, _, fscore, _ = metrics.precision_recall_fscore_support(y.t().cpu(), decisions, average='micro')
+        return {'rocauc':rocauc,
+                'prauc':prauc,
+                'fscore':fscore}
+
+    def test_end(self, outputs):
+        # avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        avg_auc = torch.stack([torch.tensor([x['rocauc']]) for x in outputs]).mean()
+        avg_prauc = torch.stack([torch.tensor([x['prauc']]) for x in outputs]).mean()
+        avg_fscore = torch.stack([torch.tensor([x['fscore']]) for x in outputs]).mean()
+        return {'rocauc':avg_auc,
                 'prauc':avg_prauc,
                 'fscore':avg_fscore}
@@ -169,6 +195,6 @@ class CNN(pl.LightningModule):
                            shuffle=True)

     @staticmethod
-    def add_model_specific_args(parent_parser, root_dir):
+    def add_model_specific_args(parent_parser):
         return parent_parser
         pass
\ No newline at end of file
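For reference, the thresholded metric this commit enables can be reproduced outside the LightningModule with a few lines of scikit-learn. This is a standalone sketch with made-up multi-label data, not the repository's exact code (which operates on transposed torch tensors moved to the CPU): continuous outputs are binarised at 0.5 and scored with a micro-averaged F-score alongside ROC-AUC and PR-AUC.

import numpy as np
from sklearn import metrics

# Made-up multi-label targets and model scores (3 examples, 3 tags).
y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0]], dtype=float)
y_score = np.array([[0.9, 0.2, 0.6],
                    [0.1, 0.8, 0.4],
                    [0.7, 0.3, 0.2]])

decisions = (y_score > 0.5).astype(float)                         # hard 0/1 decisions at the 0.5 threshold

rocauc = metrics.roc_auc_score(y_true, y_score)                   # threshold-free, rank-based
prauc = metrics.average_precision_score(y_true, y_score)          # threshold-free, PR-based
_, _, fscore, _ = metrics.precision_recall_fscore_support(
    y_true, decisions, average='micro')                           # needs the thresholded decisions

print(f'ROC-AUC {rocauc:.3f}  PR-AUC {prauc:.3f}  F-score {fscore:.3f}')

Of the three metrics, only the F-score depends on the 0.5 cut-off, which is why the hard decisions rather than the raw y_hat are passed to precision_recall_fscore_support.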
@@ -48,7 +48,8 @@ elif hostname == 'shreyan-HP':  # Laptop Shreyan
     USE_GPU = False
 else:
     PATH_DATA_ROOT = '/mnt/2tb/datasets/MTG-Jamendo'
-    PATH_DATA_CACHE = os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms')
+    # PATH_DATA_CACHE = os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms')
+    PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets'
     USE_GPU = False

 if username == 'verena':
@@ -64,43 +65,52 @@ TRAINED_MODELS_PATH = ''

 # run name
 def make_run_name(suffix=''):
-    assert ' ' not in suffix
+    # assert ' ' not in suffix
     hash = hashlib.sha1()
     hash.update(str(time.time()).encode('utf-8'))
     run_hash = hash.hexdigest()[:5]
-    name = run_hash + suffix
+    name = run_hash + f' - {suffix}'
     return name

-curr_run_name = make_run_name()
-CURR_RUN_PATH = os.path.join(PATH_RESULTS, 'runs', curr_run_name)
-if not os.path.isdir(CURR_RUN_PATH):
-    os.mkdir(CURR_RUN_PATH)
-
-# SET UP LOGGING =============================================
+CURR_RUN_PATH = ''

 filelog = logging.getLogger()
 streamlog = logging.getLogger()
 logger = logging.getLogger()

-fh = logging.FileHandler(os.path.join(CURR_RUN_PATH, f'{curr_run_name}.log'))
-sh = logging.StreamHandler()
-formatter = logging.Formatter('%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s')
-fh.setFormatter(formatter)
-sh.setFormatter(formatter)
-
-# filelog logs only to file
-filelog.addHandler(fh)
-filelog.setLevel(logging.INFO)
-
-# streamlog logs only to terminal
-streamlog.addHandler(sh)
-streamlog.setLevel(logging.INFO)
-
-# logger logs to both file and terminal
-logger.addHandler(fh)
-logger.addHandler(sh)
-logger.setLevel(logging.DEBUG)
-
-# ============================================
+def init_experiment(comment='', name=None):
+    global CURR_RUN_PATH
+    global filelog, streamlog, logger
+    if name is None:
+        curr_run_name = make_run_name(comment)
+    else:
+        curr_run_name = name
+    CURR_RUN_PATH = os.path.join(PATH_RESULTS, 'runs', curr_run_name)
+    if not os.path.isdir(CURR_RUN_PATH):
+        os.mkdir(CURR_RUN_PATH)
+
+    # SET UP LOGGING =============================================
+    fh = logging.FileHandler(os.path.join(CURR_RUN_PATH, f'{curr_run_name}.log'))
+    sh = logging.StreamHandler()
+    formatter = logging.Formatter('%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s')
+    fh.setFormatter(formatter)
+    sh.setFormatter(formatter)
+
+    # filelog logs only to file
+    filelog.addHandler(fh)
+    filelog.setLevel(logging.INFO)
+
+    # streamlog logs only to terminal
+    streamlog.addHandler(sh)
+    streamlog.setLevel(logging.INFO)
+
+    # logger logs to both file and terminal
+    logger.addHandler(fh)
+    logger.addHandler(sh)
+    logger.setLevel(logging.DEBUG)
+    # ============================================

 def write_to_file(data, path):
     # not fully implemented. unused function as of now.
...
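To make the effect of the utils changes concrete, here is a small usage sketch of the behaviour inferred from the diff; the PATH_RESULTS value and the os.makedirs call are illustrative stand-ins, not the repository's code. Run names are now a 5-character SHA-1 time hash joined to the experiment comment with ' - ', and the per-run directory and log handlers are only created when init_experiment() is called.

import hashlib
import os
import time

def make_run_name(suffix=''):
    # Mirrors the new make_run_name: short time-based hash + " - " + comment.
    h = hashlib.sha1()
    h.update(str(time.time()).encode('utf-8'))
    return h.hexdigest()[:5] + f' - {suffix}'

PATH_RESULTS = '/tmp/results'                           # stand-in for the repo's PATH_RESULTS
run_name = make_run_name('pt_lightning_exp_a')
run_path = os.path.join(PATH_RESULTS, 'runs', run_name)
os.makedirs(run_path, exist_ok=True)                    # the repo uses os.mkdir after an isdir check
print(run_path)                                         # e.g. /tmp/results/runs/3f0a1 - pt_lightning_exp_a

Note that the generated name now contains spaces, which is presumably why the assert ' ' not in suffix guard was commented out.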