Commit b4aa9c39 authored by Shreyan Chowdhury's avatar Shreyan Chowdhury

minor refactoring

parent 4869d2e5
......@@ -4,7 +4,7 @@ import torch
import librosa
import numpy as np
import pandas as pd
from utils import PATH_DATA_CACHE
from datasets.shared_data_utils import path_data_cache
slice_length = 512 #TODO: Find a better way
n_mels = 256
......@@ -163,7 +163,7 @@ def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, slicing_func=N
df_trset = H5FCachedDataset(getdatset, name, slicing_function=slicing_func,
x_name=cache_x_name,
cache_path=PATH_DATA_CACHE,
cache_path=path_data_cache,
augment_options=augment_options
)
......
......@@ -33,7 +33,7 @@ def run(hparams):
logger.info(hparams)
exp = Experiment(name=trial_name, save_dir=CURR_RUN_PATH)
# exp.tag(hparams)
exp.tag(hparams)
# callbacks
early_stop = EarlyStopping(
......@@ -52,7 +52,7 @@ def run(hparams):
)
if USE_GPU:
trainer = Trainer(gpus=[0], distributed_backend=None,
trainer = Trainer(gpus=[hparams.gpu], distributed_backend=None,
experiment=exp, max_nb_epochs=hparams.max_epochs,
train_percent_check=hparams.train_percent,
fast_dev_run=False, early_stop_callback=early_stop,
......@@ -70,6 +70,7 @@ def run(hparams):
trainer.fit(model)
test_metrics = trainer.test()
logger.info(test_metrics)
exp.log(test_metrics)
except KeyboardInterrupt:
logger.info("Run interrupted")
except Exception as e:
......@@ -93,6 +94,8 @@ def run(hparams):
if __name__ == '__main__':
parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parent_parser.add_argument('--gpu', type=int,
default=0, help='which gpu to use')
parent_parser.add_argument('--experiment_name', type=str,
default='pt_lightning_exp_a', help='test tube exp name')
parent_parser.add_argument('--train_percent', type=float,
......
......@@ -71,7 +71,7 @@ def train_mtgjamendo(hparams, midlevel_chkpt_dir):
exp = Experiment(name='mtg', save_dir=CURR_RUN_PATH)
# mtg_configs()
logger.info(f"Loading model from {midlevel_chkpt_dir}")
model = Network(model_config, hparams, num_targets=7, dataset='mtgjamendo', on_gpu=USE_GPU, load_from=midlevel_chkpt_dir)
model = Network(model_config, hparams, num_targets=7, source_dataset='mtgjamendo', on_gpu=USE_GPU, load_from=midlevel_chkpt_dir)
logger.info(f"Loaded model successfully")
early_stop = EarlyStopping(
......
......@@ -80,7 +80,7 @@ class CRNN(BasePtlModel):
def rnn_forward(x):
x = x.squeeze()
x = self.gru1(x)[1][1] # TODO: Check if this is correct
x = self.gru1(x)[0][-1] # TODO: Check if this is correct
x = self.dropout(x)
logit = nn.Sigmoid()(self.dense(x))
return logit
......
......@@ -128,11 +128,4 @@ def hdf_cache_specs(source_root, destination_root, annotations_file='all', outpu
if __name__=='__main__':
# pass
hdf_cache_specs(PATH_MELSPEC_DOWNLOADED, os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms'),
annotations_file=os.path.join(PATH_ANNOTATIONS, 'train_processed.tsv'), output_filename='train.h5')
hdf_cache_specs(PATH_MELSPEC_DOWNLOADED, os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms'),
annotations_file=os.path.join(PATH_ANNOTATIONS, 'validation_processed.tsv'), output_filename='val.h5')
hdf_cache_specs(PATH_MELSPEC_DOWNLOADED, os.path.join(PATH_DATA_ROOT, 'HDF5Cache_spectrograms'),
annotations_file=os.path.join(PATH_ANNOTATIONS, 'test_processed.tsv'), output_filename='test.h5')
# pass
\ No newline at end of file
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment