Commit cc057fed authored by Shreyan Chowdhury's avatar Shreyan Chowdhury
Browse files

Add experiment_name functionality and an F-score metric computed at a 0.5 decision threshold

parent 22b36bc3
from utils import CURR_RUN_PATH, USE_GPU, logger
from utils import USE_GPU, init_experiment
from pytorch_lightning import Trainer
from test_tube import Experiment
from test_tube import Experiment, HyperOptArgumentParser
from models.baseline import CNN as Network
def run():
def run(hparams):
init_experiment(comment=hparams.experiment_name)
from utils import CURR_RUN_PATH, logger # import these after init_experiment
logger.info(CURR_RUN_PATH)
exp = Experiment(save_dir=CURR_RUN_PATH)
......@@ -16,10 +20,10 @@ def run():
if USE_GPU:
trainer = Trainer(gpus=[0], distributed_backend='ddp',
experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
experiment=exp, max_nb_epochs=1, train_percent_check=1.0,
fast_dev_run=False)
else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
fast_dev_run=True)
model = Network(num_class=56) # TODO num_class
......@@ -31,4 +35,9 @@ def run():
if __name__ == '__main__':
    # Build the CLI: a shared parser carrying the experiment name, plus any
    # model-specific arguments the network class contributes.
    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    parent_parser.add_argument('--experiment_name', type=str,
                               default='pt_lightning_exp_a', help='test tube exp name')
    parser = Network.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()
    # run() now requires the parsed hparams (the old zero-arg call is gone).
    run(hyperparams)
......@@ -119,22 +119,48 @@ class CNN(pl.LightningModule):
y_hat = self.forward_full_song(x, y)
y = y.float()
y_hat = y_hat.float()
decisions = y_hat.t().cpu() > 0.5
decisions = decisions.type(torch.float)
rocauc = metrics.roc_auc_score(y.t().cpu(), y_hat.t().cpu())
prauc = metrics.average_precision_score(y.t().cpu(), y_hat.t().cpu())
# _, _, fscore, _ = metrics.precision_recall_fscore_support(y.t().cpu(), y_hat.t().cpu())
fscore = 0.
_, _, fscore, _ = metrics.precision_recall_fscore_support(y.t().cpu(), decisions, average='micro')
return {'val_loss': self.my_loss(y_hat, y),
'rocauc':rocauc,
'prauc':prauc,
'fscore':fscore}
def validation_end(self, outputs):
    """Aggregate per-batch validation metrics into epoch-level means.

    :param outputs: list of dicts as returned by ``validation_step``
        (each holding scalar 'rocauc', 'prauc', 'fscore' entries).
    :return: dict of mean 'rocauc', 'prauc' and 'fscore' (0-dim tensors).
    """
    # avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    avg_auc = torch.stack([torch.tensor([x['rocauc']]) for x in outputs]).mean()
    avg_prauc = torch.stack([torch.tensor([x['prauc']]) for x in outputs]).mean()
    avg_fscore = torch.stack([torch.tensor([x['fscore']]) for x in outputs]).mean()
    return {'rocauc': avg_auc,
            'prauc': avg_prauc,
            'fscore': avg_fscore}
def test_step(self, data_batch, batch_nb):
    """Compute per-batch test metrics: ROC-AUC, PR-AUC and micro F-score.

    :param data_batch: (x, _, y) triple — x is the model input, y the
        multi-label targets (shapes assumed compatible with
        forward_full_song — TODO confirm against the data loader).
    :param batch_nb: batch index (unused).
    :return: dict with scalar 'rocauc', 'prauc' and 'fscore' entries.
    """
    x, _, y = data_batch
    y_hat = self.forward_full_song(x, y)
    y = y.float()
    y_hat = y_hat.float()
    # Binarize the predictions at a fixed 0.5 threshold for the F-score.
    decisions = (y_hat.t().cpu() > 0.5).type(torch.float)
    rocauc = metrics.roc_auc_score(y.t().cpu(), y_hat.t().cpu())
    prauc = metrics.average_precision_score(y.t().cpu(), y_hat.t().cpu())
    _, _, fscore, _ = metrics.precision_recall_fscore_support(
        y.t().cpu(), decisions, average='micro')
    return {'rocauc': rocauc,
            'prauc': prauc,
            'fscore': fscore}
def test_end(self, outputs):
# avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
avg_auc = torch.stack([torch.tensor([x['rocauc']]) for x in outputs]).mean()
avg_prauc = torch.stack([torch.tensor([x['prauc']]) for x in outputs]).mean()
avg_fscore = torch.stack([torch.tensor([x['fscore']]) for x in outputs]).mean()
return {'rocauc':avg_auc,
'prauc':avg_prauc,
'fscore':avg_fscore}
......@@ -169,6 +195,6 @@ class CNN(pl.LightningModule):
shuffle=True)
@staticmethod
def add_model_specific_args(parent_parser, root_dir):
def add_model_specific_args(parent_parser):
return parent_parser
pass
\ No newline at end of file
......@@ -48,7 +48,8 @@ elif hostname == 'shreyan-HP': # Laptop Shreyan
USE_GPU = False
else:
PATH_DATA_ROOT = '/mnt/2tb/datasets/MTG-Jamendo'
PATH_DATA_CACHE = os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms')
# PATH_DATA_CACHE = os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms')
PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets'
USE_GPU = False
if username == 'verena':
......@@ -64,43 +65,52 @@ TRAINED_MODELS_PATH = ''
# run name
def make_run_name(suffix=''):
    """Return a unique run name: a 5-char time-based SHA1 prefix plus suffix.

    :param suffix: free-form comment appended as " - {suffix}"; may contain
        spaces (the old no-space assertion was deliberately relaxed).
    :return: e.g. "a1b2c - my experiment".
    """
    # assert ' ' not in suffix
    sha = hashlib.sha1()  # renamed from `hash`, which shadowed the builtin
    sha.update(str(time.time()).encode('utf-8'))
    run_hash = sha.hexdigest()[:5]
    name = run_hash + f' - {suffix}'
    return name
# Run directory and log handlers are created later, in init_experiment(),
# once the experiment name is known — not at import time.
CURR_RUN_PATH = ''

# NOTE(review): logging.getLogger() with no name always returns the ROOT
# logger, so filelog, streamlog and logger are all the SAME object and every
# handler added to any of them fires for all three — confirm whether named
# loggers (e.g. getLogger('file')) were intended.
filelog = logging.getLogger()
streamlog = logging.getLogger()
logger = logging.getLogger()
def init_experiment(comment='', name=None):
    """Create this run's output directory and attach file/stream log handlers.

    Must be called before importing CURR_RUN_PATH from this module, since
    it rebinds the module-level CURR_RUN_PATH.

    :param comment: suffix for the generated run name (ignored when ``name``
        is given).
    :param name: explicit run name; when None a hashed name is generated.
    """
    global CURR_RUN_PATH
    global filelog, streamlog, logger

    if name is None:
        curr_run_name = make_run_name(comment)
    else:
        curr_run_name = name
    CURR_RUN_PATH = os.path.join(PATH_RESULTS, 'runs', curr_run_name)
    # makedirs also creates the intermediate 'runs' directory on a fresh
    # checkout, where plain os.mkdir would raise FileNotFoundError.
    os.makedirs(CURR_RUN_PATH, exist_ok=True)

    # SET UP LOGGING =============================================
    fh = logging.FileHandler(os.path.join(CURR_RUN_PATH, f'{curr_run_name}.log'))
    sh = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s')
    fh.setFormatter(formatter)
    sh.setFormatter(formatter)

    # NOTE(review): filelog/streamlog/logger are all the root logger (see
    # their definitions), so both handlers land on one logger and a second
    # init_experiment call duplicates them — confirm intended behavior.
    # filelog logs only to file
    filelog.addHandler(fh)
    filelog.setLevel(logging.INFO)
    # streamlog logs only to terminal
    streamlog.addHandler(sh)
    streamlog.setLevel(logging.INFO)
    # logger logs to both file and terminal
    logger.addHandler(fh)
    logger.addHandler(sh)
    logger.setLevel(logging.DEBUG)
    # ============================================
def write_to_file(data, path):
# not fully implemented. unused function as of now.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment