Commit dccd0755 authored by Richard Vogl's avatar Richard Vogl
Browse files

fix loading data with test set

parent 00ac5936
......@@ -10,7 +10,7 @@ SETTINGS_FILE_NAME = 'settings.npy'
LOSSES_FILE = 'losses.npy'
LIBROSA_PATH = os.path.join(DATASET_PATH, 'librosa_disklavier')
CACHED_DS_NAME = 'cached.npy'
CACHED_DS_NAME = 'cached.npz'
if not os.path.exists(OUTPUT_PATH):
......
......@@ -47,28 +47,32 @@ def load_ramdom_data():
def load_data():
data_dict = np.load(os.path.join(LIBROSA_PATH, CACHED_DS_NAME))
train_name_list = data_dict['train_name_list']
train_feat_list = data_dict['train_feat_list']
train_targ_list = data_dict['train_targ_list']
train_names = data_dict['train_names']
train_feats = data_dict['train_feats']
train_targs = data_dict['train_targs']
test_name_list = data_dict['test_name_list']
test_feat_list = data_dict['test_feat_list']
test_targ_list = data_dict['test_targ_list']
test_names = data_dict['test_names']
test_feats = data_dict['test_feats']
test_targs = data_dict['test_targs']
train_size = len(train_feat_list)
train_size = len(train_feats)
valid_size = int(train_size * 0.15)
valid_idx = np.random.choice(range(train_size), valid_size)
valid_name_list = train_name_list[valid_idx]
valid_feat_list = train_feat_list[valid_idx]
valid_targ_list = train_targ_list[valid_idx]
valid_names = train_names[valid_idx]
valid_feats = train_feats[valid_idx]
valid_targs = train_targs[valid_idx]
train_name_list = np.delete(train_name_list, valid_idx)
train_feat_list = np.delete(train_feat_list, valid_idx)
train_targ_list = np.delete(train_targ_list, valid_idx)
train_names = np.delete(train_names, valid_idx)
train_feats = np.delete(train_feats, valid_idx)
train_targs = np.delete(train_targs, valid_idx)
return train_feat_list, train_targ_list, valid_feat_list, valid_targ_list, test_feat_list, test_targ_list,
max_seq_len = np.max([tf.shape[0] for tf in train_feats])
feat_size = train_feats[0].shape[1]
out_size = train_targs[0].shape[1]
return train_feats, train_targs, valid_feats, valid_targs, test_feats, test_targs, max_seq_len, feat_size, out_size
if __name__ == '__main__':
......
......@@ -32,7 +32,7 @@ def run(model, model_name, model_name_full, learn_rate, batch_size, split, k_sam
os.makedirs(out_directory)
print("Loading data...")
feat_train, targ_train, feat_valid, targ_valid, max_seq_len, feat_size, out_len = load_data(split)
feat_train, targ_train, feat_valid, targ_valid, feat_test, targ_test, max_seq_len, feat_size, out_len = load_data()
print("Building model...")
network = model.build_model(batch_size=batch_size, seq_length=model.SEQ_LENGTH, feat_len=feat_size, out_len=out_len)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment