Commit d671e94b authored by Verena Praher's avatar Verena Praher

add data loading for mtg dataset

parent 62972be1
......@@ -21,6 +21,18 @@ def initialize_weights(module):
module.bias.data.zero_()
def get_mtg_sets():
data_root, audio_path, csvs_path = get_paths('mtgjamendo')
cache_x_name = '_ap_mtgjamendo44k'
train_csv = os.path.join(csvs_path, 'train_processed.tsv')
validation_csv = os.path.join(csvs_path, 'validation_processed.tsv')
test_csv = os.path.join(csvs_path, 'test_processed.tsv')
trainset = df_get_mtg_set('mtgjamendo', train_csv, audio_path, cache_x_name)
validationset = df_get_mtg_set('mtgjamendo_val', validation_csv, audio_path, cache_x_name)
testset = df_get_mtg_set('mtgjamendo_test', test_csv, audio_path, cache_x_name)
return trainset, validationset, testset
def get_midlevel_sets():
data_root, audio_path, csvs_path = get_paths('midlevel')
cache_x_name = '_ap_midlevel44k'
......@@ -38,7 +50,7 @@ class ModelMidlevel(BasePtlModel):
super(ModelMidlevel, self).__init__(config, hparams)
self.midlevel_trainset, self.midlevel_valset, self.midlevel_testset = get_midlevel_sets()
# TODO same for mtg sets
self.mtg_trainset, self.mtg_valset, self.mtg_testset = get_mtg_sets()
self.num_targets = num_targets
self.conv1 = nn.Sequential(
......@@ -111,18 +123,7 @@ class ModelMidlevel(BasePtlModel):
if initialize:
self.apply(initialize_weights)
#if load_from:
# self._load_model(load_from, map_location, on_gpu)
# if dataset == 'mtgjamendo':
self.fc_mtg1 = nn.Linear(256, 56)
# self.fc_mtg2 = nn.Linear(10, 56)
# for name, param in self.named_parameters():
# if 'mtg' in name:
# param.requires_grad = True
# else:
# param.requires_grad = False
def forward(self, x, dataset_i=None):
# 313 * 149 * 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment