Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Shreyan Chowdhury
moodwalk
Commits
ddf6c7cc
Commit
ddf6c7cc
authored
Sep 14, 2019
by
Shreyan Chowdhury
Committed by
Shreyan Chowdhury
Sep 14, 2019
Browse files
add paths for shreyan-HP, minor refactoring
parent
bc052327
Changes
12
Hide whitespace changes
Inline
Side-by-side
datasets/collate.py
View file @
ddf6c7cc
import
torch
class
PadSequence
:
def
__call__
(
self
,
batch
):
# print("PadSequence is called")
...
...
datasets/midlevel.py
View file @
ddf6c7cc
...
...
@@ -17,7 +17,7 @@ def sample_slicing_function(h5data, idx, xlen):
t2_parse_labels_cache
=
{}
def
t2
_parse_labels
(
csvf
):
def
midlevel
_parse_labels
(
csvf
):
global
t2_parse_labels_cache
if
t2_parse_labels_cache
.
get
(
csvf
)
is
not
None
:
return
t2_parse_labels_cache
.
get
(
csvf
)
...
...
@@ -91,7 +91,7 @@ def df_get_midlevel_set(name, midlevel_files_csv, audio_path, cache_x_name):
print
(
"loading dataset from '{}'"
.
format
(
name
))
def
getdatset
():
files
,
labels
=
t2
_parse_labels
(
midlevel_files_csv
)
files
,
labels
=
midlevel
_parse_labels
(
midlevel_files_csv
)
return
AudioPreprocessDataset
(
files
,
labels
,
label_encoder
,
audio_path
,
audio_processor
)
df_trset
=
H5FCachedDataset
(
getdatset
,
name
,
slicing_function
=
sample_slicing_function
,
...
...
datasets/mtgjamendo.py
View file @
ddf6c7cc
...
...
@@ -31,7 +31,7 @@ def full_song_slicing_function(h5data, idx, xlen):
t2_parse_labels_cache
=
{}
def
t2
_parse_labels
(
csvf
):
def
mtgjamendo
_parse_labels
(
csvf
):
global
t2_parse_labels_cache
if
t2_parse_labels_cache
.
get
(
csvf
)
is
not
None
:
return
t2_parse_labels_cache
.
get
(
csvf
)
...
...
@@ -111,7 +111,7 @@ def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, slicing_func=N
print
(
"loading dataset from '{}'"
.
format
(
name
))
def
getdatset
():
files
,
labels
,
label_encoder
=
t2
_parse_labels
(
mtg_files_csv
)
files
,
labels
,
label_encoder
=
mtgjamendo
_parse_labels
(
mtg_files_csv
)
return
AudioPreprocessDataset
(
files
,
labels
,
label_encoder
,
audio_path
,
audio_processor
)
if
slicing_func
is
None
:
...
...
datasets/shared_data_utils.py
View file @
ddf6c7cc
...
...
@@ -64,5 +64,21 @@ if hostname == 'shreyan-All-Series':
path_mtgjamendo_annotations_test
=
'/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir
=
'/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if
hostname
not
in
[
'rechenknecht1.cp.jku.at'
,
'rechenknecht2.cp.jku.at'
,
'rechenknecht3.cp.jku.at'
,
'shreyan-All-Series'
]:
raise
Exception
(
f
"Paths not defined for
{
hostname
}
"
)
\ No newline at end of file
if
hostname
==
'shreyan-HP'
:
rk
=
'rk2'
path_data_cache
=
'/mnt/2tb/datasets/data_caches'
# midlevel
path_midlevel_annotations_dir
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/midlevel/metadata_annotations'
path_midlevel_annotations
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/midlevel/metadata_annotations/annotations.csv'
path_midlevel_audio_dir
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/midlevel/audio'
# mtgjamendo
path_mtgjamendo_annotations_dir
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations'
path_mtgjamendo_annotations_train
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/train_processed.tsv'
path_mtgjamendo_annotations_val
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/validation_processed.tsv'
path_mtgjamendo_annotations_test
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if
hostname
not
in
[
'rechenknecht1.cp.jku.at'
,
'rechenknecht2.cp.jku.at'
,
'rechenknecht3.cp.jku.at'
,
'shreyan-All-Series'
,
'shreyan-HP'
]:
raise
Exception
(
f
"Paths not defined for
{
hostname
}
"
)
experiments/experiment_crnn.py
View file @
ddf6c7cc
...
...
@@ -6,14 +6,15 @@ from models.crnn import CRNN as Network
import
os
model_config
=
{
'data_source'
:
'mtgjamendo'
,
'validation_metrics'
:[
'rocauc'
,
'prauc'
],
'test_metrics'
:[
'rocauc'
,
'prauc'
]
'data_source'
:
'mtgjamendo'
,
'validation_metrics'
:
[
'rocauc'
,
'prauc'
],
'test_metrics'
:
[
'rocauc'
,
'prauc'
]
}
initialized
=
False
# TODO: Find a better way to do this
initialized
=
False
# TODO: Find a better way to do this
trial_counter
=
0
def
run
(
hparams
):
global
initialized
,
trial_counter
trial_counter
+=
1
...
...
@@ -56,7 +57,7 @@ def run(hparams):
train_percent_check
=
hparams
.
train_percent
,
fast_dev_run
=
False
,
early_stop_callback
=
early_stop
,
checkpoint_callback
=
checkpoint_callback
,
nb_sanity_val_steps
=
0
)
# don't run sanity validation run
nb_sanity_val_steps
=
0
)
# don't run sanity validation run
else
:
trainer
=
Trainer
(
experiment
=
exp
,
max_nb_epochs
=
1
,
train_percent_check
=
0.1
,
fast_dev_run
=
True
)
...
...
@@ -68,7 +69,7 @@ def run(hparams):
trainer
.
test
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parent_parser
=
HyperOptArgumentParser
(
strategy
=
'grid_search'
,
add_help
=
False
)
parent_parser
.
add_argument
(
'--experiment_name'
,
type
=
str
,
default
=
'pt_lightning_exp_a'
,
help
=
'test tube exp name'
)
...
...
@@ -76,14 +77,14 @@ if __name__=='__main__':
default
=
1.0
,
help
=
'how much train data to use'
)
parent_parser
.
add_argument
(
'--max_epochs'
,
type
=
int
,
default
=
10
,
help
=
'maximum number of epochs'
)
#parent_parser.add_argument('--gpus', type=list, default=[0,1],
#
parent_parser.add_argument('--gpus', type=list, default=[0,1],
# help='how many gpus to use in the node.'
# ' value -1 uses all the gpus on the node')
parser
=
Network
.
add_model_specific_args
(
parent_parser
)
hyperparams
=
parser
.
parse_args
()
# run(hyperparams)
#gpus = ['cuda:0', 'cuda:1']
#hyperparams.optimize_parallel_gpu(run, gpus, 5)
#
gpus = ['cuda:0', 'cuda:1']
#
hyperparams.optimize_parallel_gpu(run, gpus, 5)
# run(hyperparams)
for
hparam_trial
in
hyperparams
.
trials
(
20
):
run
(
hparam_trial
)
models/__init__.py
deleted
100644 → 0
View file @
bc052327
models/shared_stuff.py
View file @
ddf6c7cc
...
...
@@ -15,10 +15,10 @@ from datasets.collate import PadSequence
# example config dict
base_model_config
=
{
'data_source'
:
'mtgjamendo'
,
'training_metrics'
:[
'loss'
],
'validation_metrics'
:[
'loss'
,
'prauc'
,
'rocauc'
],
'test_metrics'
:[
'loss'
,
'prauc'
,
'rocauc'
]
'data_source'
:
'mtgjamendo'
,
'training_metrics'
:
[
'loss'
],
'validation_metrics'
:
[
'loss'
,
'prauc'
,
'rocauc'
],
'test_metrics'
:
[
'loss'
,
'prauc'
,
'rocauc'
]
}
...
...
@@ -42,7 +42,7 @@ class BasePtlModel(pl.LightningModule):
self
.
validation_metrics
=
config
.
get
(
'validation_metrics'
)
self
.
test_metrics
=
config
.
get
(
'test_metrics'
)
if
self
.
data_source
==
'midlevel'
:
if
self
.
data_source
==
'midlevel'
:
dataset
,
dataset_length
=
df_get_midlevel_set
(
'midlevel'
,
path_midlevel_annotations
,
path_midlevel_audio_dir
,
...
...
@@ -177,27 +177,26 @@ class BasePtlModel(pl.LightningModule):
@
pl
.
data_loader
def
tng_dataloader
(
self
):
if
self
.
data_source
==
'mtgjamendo'
:
if
self
.
data_source
==
'mtgjamendo'
:
dataset
=
df_get_mtg_set
(
'mtgjamendo'
,
path_mtgjamendo_annotations_train
,
path_mtgjamendo_audio_dir
,
"_ap_mtgjamendo44k"
,
slicing_func
=
self
.
slicer
,
slice_len
=
self
.
input_size
)
elif
self
.
data_source
==
'midlevel'
:
elif
self
.
data_source
==
'midlevel'
:
dataset
=
self
.
midlevel_trainset
else
:
raise
Exception
(
f
"Data source
{
self
.
data_source
}
not defined"
)
if
self
.
slicing_mode
==
'full'
:
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
,
collate_fn
=
PadSequence
())
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
)
else
:
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
)
@
pl
.
data_loader
def
val_dataloader
(
self
):
...
...
@@ -207,10 +206,8 @@ class BasePtlModel(pl.LightningModule):
path_mtgjamendo_audio_dir
,
"_ap_mtgjamendo44k"
,
slicing_func
=
self
.
slicer
,
slice_len
=
self
.
input_size
)
elif
self
.
data_source
==
'midlevel'
:
dataset
=
self
.
midlevel_valset
else
:
raise
Exception
(
f
"Data source
{
self
.
data_source
}
not defined"
)
...
...
@@ -219,10 +216,10 @@ class BasePtlModel(pl.LightningModule):
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
,
collate_fn
=
PadSequence
())
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
)
else
:
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
)
@
pl
.
data_loader
def
test_dataloader
(
self
):
...
...
@@ -232,10 +229,8 @@ class BasePtlModel(pl.LightningModule):
path_mtgjamendo_audio_dir
,
"_ap_mtgjamendo44k"
,
slicing_func
=
self
.
slicer
,
slice_len
=
self
.
input_size
)
elif
self
.
data_source
==
'midlevel'
:
dataset
=
self
.
midlevel_testset
else
:
raise
Exception
(
f
"Data source
{
self
.
data_source
}
not defined"
)
...
...
@@ -244,7 +239,7 @@ class BasePtlModel(pl.LightningModule):
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
,
collate_fn
=
PadSequence
())
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
)
else
:
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
)
results/tag_freq_test.png
deleted
100644 → 0
View file @
bc052327
193 KB
results/tag_freq_train.png
deleted
100644 → 0
View file @
bc052327
213 KB
results/tag_freq_validation.png
deleted
100644 → 0
View file @
bc052327
206 KB
tagslist.npy
deleted
100644 → 0
View file @
bc052327
File deleted
utils.py
View file @
ddf6c7cc
...
...
@@ -70,6 +70,8 @@ elif hostname == 'verena-830g5': # Laptop Verena
USE_GPU
=
False
elif
hostname
==
'shreyan-HP'
:
# Laptop Shreyan
USE_GPU
=
False
PATH_DATA_CACHE
=
'/home/shreyan/mounts/home@rk2/shared/kofta_cached_datasets'
else
:
# PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets'
PATH_DATA_CACHE
=
'/mnt/2tb/datasets/data_caches'
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment