Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Shreyan Chowdhury
moodwalk
Commits
ddf6c7cc
Commit
ddf6c7cc
authored
Sep 14, 2019
by
Shreyan Chowdhury
Committed by
Shreyan Chowdhury
Sep 14, 2019
Browse files
add paths for shreyan-HP, minor refactoring
parent
bc052327
Changes
12
Show whitespace changes
Inline
Side-by-side
datasets/collate.py
View file @
ddf6c7cc
import
torch
import
torch
class
PadSequence
:
class
PadSequence
:
def
__call__
(
self
,
batch
):
def
__call__
(
self
,
batch
):
# print("PadSequence is called")
# print("PadSequence is called")
...
...
datasets/midlevel.py
View file @
ddf6c7cc
...
@@ -17,7 +17,7 @@ def sample_slicing_function(h5data, idx, xlen):
...
@@ -17,7 +17,7 @@ def sample_slicing_function(h5data, idx, xlen):
t2_parse_labels_cache
=
{}
t2_parse_labels_cache
=
{}
def
t2
_parse_labels
(
csvf
):
def
midlevel
_parse_labels
(
csvf
):
global
t2_parse_labels_cache
global
t2_parse_labels_cache
if
t2_parse_labels_cache
.
get
(
csvf
)
is
not
None
:
if
t2_parse_labels_cache
.
get
(
csvf
)
is
not
None
:
return
t2_parse_labels_cache
.
get
(
csvf
)
return
t2_parse_labels_cache
.
get
(
csvf
)
...
@@ -91,7 +91,7 @@ def df_get_midlevel_set(name, midlevel_files_csv, audio_path, cache_x_name):
...
@@ -91,7 +91,7 @@ def df_get_midlevel_set(name, midlevel_files_csv, audio_path, cache_x_name):
print
(
"loading dataset from '{}'"
.
format
(
name
))
print
(
"loading dataset from '{}'"
.
format
(
name
))
def
getdatset
():
def
getdatset
():
files
,
labels
=
t2
_parse_labels
(
midlevel_files_csv
)
files
,
labels
=
midlevel
_parse_labels
(
midlevel_files_csv
)
return
AudioPreprocessDataset
(
files
,
labels
,
label_encoder
,
audio_path
,
audio_processor
)
return
AudioPreprocessDataset
(
files
,
labels
,
label_encoder
,
audio_path
,
audio_processor
)
df_trset
=
H5FCachedDataset
(
getdatset
,
name
,
slicing_function
=
sample_slicing_function
,
df_trset
=
H5FCachedDataset
(
getdatset
,
name
,
slicing_function
=
sample_slicing_function
,
...
...
datasets/mtgjamendo.py
View file @
ddf6c7cc
...
@@ -31,7 +31,7 @@ def full_song_slicing_function(h5data, idx, xlen):
...
@@ -31,7 +31,7 @@ def full_song_slicing_function(h5data, idx, xlen):
t2_parse_labels_cache
=
{}
t2_parse_labels_cache
=
{}
def
t2
_parse_labels
(
csvf
):
def
mtgjamendo
_parse_labels
(
csvf
):
global
t2_parse_labels_cache
global
t2_parse_labels_cache
if
t2_parse_labels_cache
.
get
(
csvf
)
is
not
None
:
if
t2_parse_labels_cache
.
get
(
csvf
)
is
not
None
:
return
t2_parse_labels_cache
.
get
(
csvf
)
return
t2_parse_labels_cache
.
get
(
csvf
)
...
@@ -111,7 +111,7 @@ def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, slicing_func=N
...
@@ -111,7 +111,7 @@ def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, slicing_func=N
print
(
"loading dataset from '{}'"
.
format
(
name
))
print
(
"loading dataset from '{}'"
.
format
(
name
))
def
getdatset
():
def
getdatset
():
files
,
labels
,
label_encoder
=
t2
_parse_labels
(
mtg_files_csv
)
files
,
labels
,
label_encoder
=
mtgjamendo
_parse_labels
(
mtg_files_csv
)
return
AudioPreprocessDataset
(
files
,
labels
,
label_encoder
,
audio_path
,
audio_processor
)
return
AudioPreprocessDataset
(
files
,
labels
,
label_encoder
,
audio_path
,
audio_processor
)
if
slicing_func
is
None
:
if
slicing_func
is
None
:
...
...
datasets/shared_data_utils.py
View file @
ddf6c7cc
...
@@ -64,5 +64,21 @@ if hostname == 'shreyan-All-Series':
...
@@ -64,5 +64,21 @@ if hostname == 'shreyan-All-Series':
path_mtgjamendo_annotations_test
=
'/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_annotations_test
=
'/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir
=
'/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_audio'
path_mtgjamendo_audio_dir
=
'/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if
hostname
not
in
[
'rechenknecht1.cp.jku.at'
,
'rechenknecht2.cp.jku.at'
,
'rechenknecht3.cp.jku.at'
,
'shreyan-All-Series'
]:
if
hostname
==
'shreyan-HP'
:
rk
=
'rk2'
path_data_cache
=
'/mnt/2tb/datasets/data_caches'
# midlevel
path_midlevel_annotations_dir
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/midlevel/metadata_annotations'
path_midlevel_annotations
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/midlevel/metadata_annotations/annotations.csv'
path_midlevel_audio_dir
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/midlevel/audio'
# mtgjamendo
path_mtgjamendo_annotations_dir
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations'
path_mtgjamendo_annotations_train
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/train_processed.tsv'
path_mtgjamendo_annotations_val
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/validation_processed.tsv'
path_mtgjamendo_annotations_test
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir
=
f
'/home/shreyan/mounts/home@
{
rk
}
/shared/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if
hostname
not
in
[
'rechenknecht1.cp.jku.at'
,
'rechenknecht2.cp.jku.at'
,
'rechenknecht3.cp.jku.at'
,
'shreyan-All-Series'
,
'shreyan-HP'
]:
raise
Exception
(
f
"Paths not defined for
{
hostname
}
"
)
raise
Exception
(
f
"Paths not defined for
{
hostname
}
"
)
experiments/experiment_crnn.py
View file @
ddf6c7cc
...
@@ -6,14 +6,15 @@ from models.crnn import CRNN as Network
...
@@ -6,14 +6,15 @@ from models.crnn import CRNN as Network
import
os
import
os
model_config
=
{
model_config
=
{
'data_source'
:
'mtgjamendo'
,
'data_source'
:
'mtgjamendo'
,
'validation_metrics'
:[
'rocauc'
,
'prauc'
],
'validation_metrics'
:
[
'rocauc'
,
'prauc'
],
'test_metrics'
:[
'rocauc'
,
'prauc'
]
'test_metrics'
:
[
'rocauc'
,
'prauc'
]
}
}
initialized
=
False
# TODO: Find a better way to do this
initialized
=
False
# TODO: Find a better way to do this
trial_counter
=
0
trial_counter
=
0
def
run
(
hparams
):
def
run
(
hparams
):
global
initialized
,
trial_counter
global
initialized
,
trial_counter
trial_counter
+=
1
trial_counter
+=
1
...
@@ -68,7 +69,7 @@ def run(hparams):
...
@@ -68,7 +69,7 @@ def run(hparams):
trainer
.
test
()
trainer
.
test
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parent_parser
=
HyperOptArgumentParser
(
strategy
=
'grid_search'
,
add_help
=
False
)
parent_parser
=
HyperOptArgumentParser
(
strategy
=
'grid_search'
,
add_help
=
False
)
parent_parser
.
add_argument
(
'--experiment_name'
,
type
=
str
,
parent_parser
.
add_argument
(
'--experiment_name'
,
type
=
str
,
default
=
'pt_lightning_exp_a'
,
help
=
'test tube exp name'
)
default
=
'pt_lightning_exp_a'
,
help
=
'test tube exp name'
)
...
@@ -76,14 +77,14 @@ if __name__=='__main__':
...
@@ -76,14 +77,14 @@ if __name__=='__main__':
default
=
1.0
,
help
=
'how much train data to use'
)
default
=
1.0
,
help
=
'how much train data to use'
)
parent_parser
.
add_argument
(
'--max_epochs'
,
type
=
int
,
parent_parser
.
add_argument
(
'--max_epochs'
,
type
=
int
,
default
=
10
,
help
=
'maximum number of epochs'
)
default
=
10
,
help
=
'maximum number of epochs'
)
#parent_parser.add_argument('--gpus', type=list, default=[0,1],
#
parent_parser.add_argument('--gpus', type=list, default=[0,1],
# help='how many gpus to use in the node.'
# help='how many gpus to use in the node.'
# ' value -1 uses all the gpus on the node')
# ' value -1 uses all the gpus on the node')
parser
=
Network
.
add_model_specific_args
(
parent_parser
)
parser
=
Network
.
add_model_specific_args
(
parent_parser
)
hyperparams
=
parser
.
parse_args
()
hyperparams
=
parser
.
parse_args
()
# run(hyperparams)
# run(hyperparams)
#gpus = ['cuda:0', 'cuda:1']
#
gpus = ['cuda:0', 'cuda:1']
#hyperparams.optimize_parallel_gpu(run, gpus, 5)
#
hyperparams.optimize_parallel_gpu(run, gpus, 5)
# run(hyperparams)
# run(hyperparams)
for
hparam_trial
in
hyperparams
.
trials
(
20
):
for
hparam_trial
in
hyperparams
.
trials
(
20
):
run
(
hparam_trial
)
run
(
hparam_trial
)
models/__init__.py
deleted
100644 → 0
View file @
bc052327
models/shared_stuff.py
View file @
ddf6c7cc
...
@@ -15,10 +15,10 @@ from datasets.collate import PadSequence
...
@@ -15,10 +15,10 @@ from datasets.collate import PadSequence
# example config dict
# example config dict
base_model_config
=
{
base_model_config
=
{
'data_source'
:
'mtgjamendo'
,
'data_source'
:
'mtgjamendo'
,
'training_metrics'
:[
'loss'
],
'training_metrics'
:
[
'loss'
],
'validation_metrics'
:[
'loss'
,
'prauc'
,
'rocauc'
],
'validation_metrics'
:
[
'loss'
,
'prauc'
,
'rocauc'
],
'test_metrics'
:[
'loss'
,
'prauc'
,
'rocauc'
]
'test_metrics'
:
[
'loss'
,
'prauc'
,
'rocauc'
]
}
}
...
@@ -42,7 +42,7 @@ class BasePtlModel(pl.LightningModule):
...
@@ -42,7 +42,7 @@ class BasePtlModel(pl.LightningModule):
self
.
validation_metrics
=
config
.
get
(
'validation_metrics'
)
self
.
validation_metrics
=
config
.
get
(
'validation_metrics'
)
self
.
test_metrics
=
config
.
get
(
'test_metrics'
)
self
.
test_metrics
=
config
.
get
(
'test_metrics'
)
if
self
.
data_source
==
'midlevel'
:
if
self
.
data_source
==
'midlevel'
:
dataset
,
dataset_length
=
df_get_midlevel_set
(
'midlevel'
,
dataset
,
dataset_length
=
df_get_midlevel_set
(
'midlevel'
,
path_midlevel_annotations
,
path_midlevel_annotations
,
path_midlevel_audio_dir
,
path_midlevel_audio_dir
,
...
@@ -177,24 +177,23 @@ class BasePtlModel(pl.LightningModule):
...
@@ -177,24 +177,23 @@ class BasePtlModel(pl.LightningModule):
@
pl
.
data_loader
@
pl
.
data_loader
def
tng_dataloader
(
self
):
def
tng_dataloader
(
self
):
if
self
.
data_source
==
'mtgjamendo'
:
if
self
.
data_source
==
'mtgjamendo'
:
dataset
=
df_get_mtg_set
(
'mtgjamendo'
,
dataset
=
df_get_mtg_set
(
'mtgjamendo'
,
path_mtgjamendo_annotations_train
,
path_mtgjamendo_annotations_train
,
path_mtgjamendo_audio_dir
,
path_mtgjamendo_audio_dir
,
"_ap_mtgjamendo44k"
,
slicing_func
=
self
.
slicer
,
"_ap_mtgjamendo44k"
,
slicing_func
=
self
.
slicer
,
slice_len
=
self
.
input_size
)
slice_len
=
self
.
input_size
)
elif
self
.
data_source
==
'midlevel'
:
elif
self
.
data_source
==
'midlevel'
:
dataset
=
self
.
midlevel_trainset
dataset
=
self
.
midlevel_trainset
else
:
else
:
raise
Exception
(
f
"Data source
{
self
.
data_source
}
not defined"
)
raise
Exception
(
f
"Data source
{
self
.
data_source
}
not defined"
)
if
self
.
slicing_mode
==
'full'
:
if
self
.
slicing_mode
==
'full'
:
return
DataLoader
(
dataset
=
dataset
,
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
self
.
hparams
.
batch_size
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
,
shuffle
=
True
,
collate_fn
=
PadSequence
())
collate_fn
=
PadSequence
())
else
:
return
DataLoader
(
dataset
=
dataset
,
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
self
.
hparams
.
batch_size
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
)
shuffle
=
True
)
...
@@ -207,10 +206,8 @@ class BasePtlModel(pl.LightningModule):
...
@@ -207,10 +206,8 @@ class BasePtlModel(pl.LightningModule):
path_mtgjamendo_audio_dir
,
path_mtgjamendo_audio_dir
,
"_ap_mtgjamendo44k"
,
slicing_func
=
self
.
slicer
,
"_ap_mtgjamendo44k"
,
slicing_func
=
self
.
slicer
,
slice_len
=
self
.
input_size
)
slice_len
=
self
.
input_size
)
elif
self
.
data_source
==
'midlevel'
:
elif
self
.
data_source
==
'midlevel'
:
dataset
=
self
.
midlevel_valset
dataset
=
self
.
midlevel_valset
else
:
else
:
raise
Exception
(
f
"Data source
{
self
.
data_source
}
not defined"
)
raise
Exception
(
f
"Data source
{
self
.
data_source
}
not defined"
)
...
@@ -219,7 +216,7 @@ class BasePtlModel(pl.LightningModule):
...
@@ -219,7 +216,7 @@ class BasePtlModel(pl.LightningModule):
batch_size
=
self
.
hparams
.
batch_size
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
,
shuffle
=
True
,
collate_fn
=
PadSequence
())
collate_fn
=
PadSequence
())
else
:
return
DataLoader
(
dataset
=
dataset
,
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
self
.
hparams
.
batch_size
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
)
shuffle
=
True
)
...
@@ -232,10 +229,8 @@ class BasePtlModel(pl.LightningModule):
...
@@ -232,10 +229,8 @@ class BasePtlModel(pl.LightningModule):
path_mtgjamendo_audio_dir
,
path_mtgjamendo_audio_dir
,
"_ap_mtgjamendo44k"
,
slicing_func
=
self
.
slicer
,
"_ap_mtgjamendo44k"
,
slicing_func
=
self
.
slicer
,
slice_len
=
self
.
input_size
)
slice_len
=
self
.
input_size
)
elif
self
.
data_source
==
'midlevel'
:
elif
self
.
data_source
==
'midlevel'
:
dataset
=
self
.
midlevel_testset
dataset
=
self
.
midlevel_testset
else
:
else
:
raise
Exception
(
f
"Data source
{
self
.
data_source
}
not defined"
)
raise
Exception
(
f
"Data source
{
self
.
data_source
}
not defined"
)
...
@@ -244,7 +239,7 @@ class BasePtlModel(pl.LightningModule):
...
@@ -244,7 +239,7 @@ class BasePtlModel(pl.LightningModule):
batch_size
=
self
.
hparams
.
batch_size
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
,
shuffle
=
True
,
collate_fn
=
PadSequence
())
collate_fn
=
PadSequence
())
else
:
return
DataLoader
(
dataset
=
dataset
,
return
DataLoader
(
dataset
=
dataset
,
batch_size
=
self
.
hparams
.
batch_size
,
batch_size
=
self
.
hparams
.
batch_size
,
shuffle
=
True
)
shuffle
=
True
)
results/tag_freq_test.png
deleted
100644 → 0
View file @
bc052327
193 KB
results/tag_freq_train.png
deleted
100644 → 0
View file @
bc052327
213 KB
results/tag_freq_validation.png
deleted
100644 → 0
View file @
bc052327
206 KB
tagslist.npy
deleted
100644 → 0
View file @
bc052327
File deleted
utils.py
View file @
ddf6c7cc
...
@@ -70,6 +70,8 @@ elif hostname == 'verena-830g5': # Laptop Verena
...
@@ -70,6 +70,8 @@ elif hostname == 'verena-830g5': # Laptop Verena
USE_GPU
=
False
USE_GPU
=
False
elif
hostname
==
'shreyan-HP'
:
# Laptop Shreyan
elif
hostname
==
'shreyan-HP'
:
# Laptop Shreyan
USE_GPU
=
False
USE_GPU
=
False
PATH_DATA_CACHE
=
'/home/shreyan/mounts/home@rk2/shared/kofta_cached_datasets'
else
:
else
:
# PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets'
# PATH_DATA_CACHE = '/home/shreyan/mounts/home@rk3/shared/kofta_cached_datasets'
PATH_DATA_CACHE
=
'/mnt/2tb/datasets/data_caches'
PATH_DATA_CACHE
=
'/mnt/2tb/datasets/data_caches'
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment