Commit fb6183fc authored by Shreyan Chowdhury's avatar Shreyan Chowdhury
Browse files

implement crnn, make BasePtlModel base class

parent 4def4edd
import os
import getpass
hostname = os.uname()[1]
username = getpass.getuser()
if hostname == 'rechenknecht1.cp.jku.at':
path_data_cache = '/media/rk1/shared/kofta_cached_datasets'
# midlevel
path_midlevel_annotations_dir = '/media/rk1/shared/datasets/midlevel/metadata_annotations'
path_midlevel_annotations = '/media/rk1/shared/datasets/midlevel/metadata_annotations/annotations.csv'
path_midlevel_audio_dir = '/media/rk1/shared/datasets/midlevel/audio'
# mtgjamendo
path_mtgjamendo_annotations_dir = '/media/rk1/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations'
path_mtgjamendo_annotations_train = '/media/rk1/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/train_processed.tsv'
path_mtgjamendo_annotations_val = '/media/rk1/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/val_processed.tsv'
path_mtgjamendo_annotations_test = '/media/rk1/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir = '/media/rk1/shared/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if hostname == 'rechenknecht2.cp.jku.at':
path_data_cache = '/media/rk2/shared/kofta_cached_datasets'
# midlevel
path_midlevel_annotations_dir = '/media/rk2/shared/datasets/midlevel/metadata_annotations'
path_midlevel_annotations = '/media/rk2/shared/datasets/midlevel/metadata_annotations/annotations.csv'
path_midlevel_audio_dir = '/media/rk2/shared/datasets/midlevel/audio'
# mtgjamendo
path_mtgjamendo_annotations_dir = '/media/rk2/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations'
path_mtgjamendo_annotations_train = '/media/rk2/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/train_processed.tsv'
path_mtgjamendo_annotations_val = '/media/rk2/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/validation_processed.tsv'
path_mtgjamendo_annotations_test = '/media/rk2/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir = '/media/rk2/shared/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if hostname == 'rechenknecht3.cp.jku.at':
path_data_cache = '/media/rk3/shared/kofta_cached_datasets'
# midlevel
path_midlevel_annotations_dir = '/media/rk3/shared/datasets/midlevel/metadata_annotations'
path_midlevel_annotations = '/media/rk3/shared/datasets/midlevel/metadata_annotations/annotations.csv'
path_midlevel_audio_dir = '/media/rk3/shared/datasets/midlevel/audio'
# mtgjamendo
path_mtgjamendo_annotations_dir = '/media/rk3/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations'
path_mtgjamendo_annotations_train = '/media/rk3/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/train_processed.tsv'
path_mtgjamendo_annotations_val = '/media/rk3/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/validation_processed.tsv'
path_mtgjamendo_annotations_test = '/media/rk3/shared/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir = '/media/rk3/shared/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if hostname == 'shreyan-All-Series':
path_data_cache = '/mnt/2tb/datasets/data_caches'
# midlevel
path_midlevel_annotations_dir = '/mnt/2tb/datasets/MidlevelFeatures/metadata_annotations'
path_midlevel_annotations = '/mnt/2tb/datasets/MidlevelFeatures/metadata_annotations/annotations.csv'
path_midlevel_audio_dir = '/mnt/2tb/datasets/MidlevelFeatures/audio'
# mtgjamendo
path_mtgjamendo_annotations_dir = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_annotations'
path_mtgjamendo_annotations_train = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_annotations/train_processed.tsv'
path_mtgjamendo_annotations_val = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_annotations/validation_processed.tsv'
path_mtgjamendo_annotations_test = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_annotations/test_processed.tsv'
path_mtgjamendo_audio_dir = '/mnt/2tb/datasets/MTG-Jamendo/MTG-Jamendo_audio'
if hostname not in ['rechenknecht1.cp.jku.at', 'rechenknecht2.cp.jku.at', 'rechenknecht3.cp.jku.at', 'shreyan-All-Series']:
raise Exception(f"Paths not defined for {hostname}")
\ No newline at end of file
......@@ -83,8 +83,9 @@ def run(hparams):
experiment=exp, max_nb_epochs=config['epochs'], train_percent_check=hparams.train_percent,
fast_dev_run=False,
early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback
)
checkpoint_callback=checkpoint_callback,
nb_sanity_val_steps=0) # don't run sanity validation run
else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.01,
fast_dev_run=True)
......
from utils import USE_GPU, init_experiment
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from test_tube import Experiment, HyperOptArgumentParser
from models.crnn import CRNN as Network
import os
model_config = {
'data_source':'mtgjamendo',
'validation_metrics':['rocauc', 'prauc'],
'test_metrics':['rocauc', 'prauc']
}
initialized = False # TODO: Find a better way to do this
def run(hparams):
if not initialized:
init_experiment(comment=hparams.experiment_name)
from utils import CURR_RUN_PATH, logger # import these after init_experiment
logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
logger.info(hparams)
exp = Experiment(save_dir=CURR_RUN_PATH)
# callbacks
early_stop = EarlyStopping(
monitor='val_loss',
patience=20,
verbose=True,
mode='min'
)
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(CURR_RUN_PATH, 'best.ckpt'),
save_best_only=True,
verbose=True,
monitor='val_loss',
mode='min'
)
if USE_GPU:
trainer = Trainer(gpus=[0], distributed_backend=None,
experiment=exp, max_nb_epochs=hparams.max_epochs,
train_percent_check=hparams.train_percent,
fast_dev_run=False, early_stop_callback=early_stop,
checkpoint_callback=checkpoint_callback,
nb_sanity_val_steps=0) # don't run sanity validation run
else:
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1,
fast_dev_run=True)
model = Network(num_class=56, config=model_config, hparams=hparams)
print(model)
trainer.fit(model)
trainer.test()
if __name__=='__main__':
parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
parent_parser.add_argument('--experiment_name', type=str,
default='pt_lightning_exp_a', help='test tube exp name')
parent_parser.add_argument('--train_percent', type=float,
default=1.0, help='how much train data to use')
parent_parser.add_argument('--max_epochs', type=int,
default=10, help='maximum number of epochs')
#parent_parser.add_argument('--gpus', type=list, default=[0,1],
# help='how many gpus to use in the node.'
# ' value -1 uses all the gpus on the node')
parser = Network.add_model_specific_args(parent_parser)
hyperparams = parser.parse_args()
# run(hyperparams)
#gpus = ['cuda:0', 'cuda:1']
#hyperparams.optimize_parallel_gpu(run, gpus, 5)
for hparam_trial in hyperparams.trials(20):
run(hparam_trial)
......@@ -47,25 +47,9 @@ def pretrain_midlevel(hparams):
logger.info(f"tensorboard --logdir={CURR_RUN_PATH}")
exp = Experiment(name='midlevel', save_dir=CURR_RUN_PATH)
def setup_config():
def print_config():
global config
st = '---------CONFIG--------\n'
for k in config.keys():
st += k+':'+str(config.get(k))+'\n'
return st
conf = hparams.config
if conf is not None:
conf_func = globals()[conf]
try:
conf_func()
except:
logger.error(f"Config {conf} not defined")
logger.info(print_config())
setup_config()
midlevel_configs()
logger.info(print_config())
early_stop = EarlyStopping(
monitor=config['earlystopping_metric'],
......
import math
import sys
import time
import numpy as np
import wave
import scipy
import scipy.signal
from pylab import *
import array
import os
from os.path import expanduser
import scipy.io.wavfile
# Author: Brian K. Vogel
# brian.vogel@gmail.com
def hz_to_mel(f_hz):
"""Convert Hz to mel scale.
This uses the formula from O'Shaugnessy's book.
Args:
f_hz (float): The value in Hz.
Returns:
The value in mels.
"""
return 2595*np.log10(1.0 + f_hz/700.0)
def mel_to_hz(m_mel):
"""Convert mel scale to Hz.
This uses the formula from O'Shaugnessy's book.
Args:
m_mel (float): The value in mels
Returns:
The value in Hz
"""
return 700*(10**(m_mel/2595) - 1.0)
def fft_bin_to_hz(n_bin, sample_rate_hz, fft_size):
"""Convert FFT bin index to frequency in Hz.
Args:
n_bin (int or float): The FFT bin index.
sample_rate_hz (int or float): The sample rate in Hz.
fft_size (int or float): The FFT size.
Returns:
The value in Hz.
"""
n_bin = float(n_bin)
sample_rate_hz = float(sample_rate_hz)
fft_size = float(fft_size)
return n_bin*sample_rate_hz/(2.0*fft_size)
def hz_to_fft_bin(f_hz, sample_rate_hz, fft_size):
"""Convert frequency in Hz to FFT bin index.
Args:
f_hz (int or float): The frequency in Hz.
sample_rate_hz (int or float): The sample rate in Hz.
fft_size (int or float): The FFT size.
Returns:
The FFT bin index as an int.
"""
f_hz = float(f_hz)
sample_rate_hz = float(sample_rate_hz)
fft_size = float(fft_size)
fft_bin = int(np.round((f_hz*2.0*fft_size/sample_rate_hz)))
if fft_bin >= fft_size:
fft_bin = fft_size-1
return fft_bin
def make_mel_filterbank(min_freq_hz, max_freq_hz, mel_bin_count,
linear_bin_count, sample_rate_hz):
"""Create a mel filterbank matrix.
Create and return a mel filterbank matrix `filterbank` of shape (`mel_bin_count`,
`linear_bin_couont`). The `filterbank` matrix can be used to transform a
(linear scale) spectrum or spectrogram into a mel scale spectrum or
spectrogram as follows:
`mel_scale_spectrum` = `filterbank`*'linear_scale_spectrum'
where linear_scale_spectrum' is a shape (`linear_bin_count`, `m`) and
`mel_scale_spectrum` is shape ('mel_bin_count', `m`) where `m` is the number
of spectral time slices.
Likewise, the reverse-direction transform can be performed as:
'linear_scale_spectrum' = filterbank.T`*`mel_scale_spectrum`
Note that the process of converting to mel scale and then back to linear
scale is lossy.
This function computes the mel-spaced filters such that each filter is triangular
(in linear frequency) with response 1 at the center frequency and decreases linearly
to 0 upon reaching an adjacent filter's center frequency. Note that any two adjacent
filters will overlap having a response of 0.5 at the mean frequency of their
respective center frequencies.
Args:
min_freq_hz (float): The frequency in Hz corresponding to the lowest
mel scale bin.
max_freq_hz (flloat): The frequency in Hz corresponding to the highest
mel scale bin.
mel_bin_count (int): The number of mel scale bins.
linear_bin_count (int): The number of linear scale (fft) bins.
sample_rate_hz (float): The sample rate in Hz.
Returns:
The mel filterbank matrix as an 2-dim Numpy array.
"""
min_mels = hz_to_mel(min_freq_hz)
max_mels = hz_to_mel(max_freq_hz)
# Create mel_bin_count linearly spaced values between these extreme mel values.
mel_lin_spaced = np.linspace(min_mels, max_mels, num=mel_bin_count)
# Map each of these mel values back into linear frequency (Hz).
center_frequencies_hz = np.array([mel_to_hz(n) for n in mel_lin_spaced])
mels_per_bin = float(max_mels - min_mels)/float(mel_bin_count - 1)
mels_start = min_mels - mels_per_bin
hz_start = mel_to_hz(mels_start)
fft_bin_start = hz_to_fft_bin(hz_start, sample_rate_hz, linear_bin_count)
#print('fft_bin_start: ', fft_bin_start)
mels_end = max_mels + mels_per_bin
hz_stop = mel_to_hz(mels_end)
fft_bin_stop = hz_to_fft_bin(hz_stop, sample_rate_hz, linear_bin_count)
#print('fft_bin_stop: ', fft_bin_stop)
# Map each center frequency to the closest fft bin index.
linear_bin_indices = np.array([hz_to_fft_bin(f_hz, sample_rate_hz, linear_bin_count) for f_hz in center_frequencies_hz])
# Create filterbank matrix.
filterbank = np.zeros((mel_bin_count, linear_bin_count))
for mel_bin in range(mel_bin_count):
center_freq_linear_bin = linear_bin_indices[mel_bin]
# Create a triangular filter having the current center freq.
# The filter will start with 0 response at left_bin (if it exists)
# and ramp up to 1.0 at center_freq_linear_bin, and then ramp
# back down to 0 response at right_bin (if it exists).
# Create the left side of the triangular filter that ramps up
# from 0 to a response of 1 at the center frequency.
if center_freq_linear_bin > 1:
# It is possible to create the left triangular filter.
if mel_bin == 0:
# Since this is the first center frequency, the left side
# must start ramping up from linear bin 0 or 1 mel bin before the center freq.
left_bin = max(0, fft_bin_start)
else:
# Start ramping up from the previous center frequency bin.
left_bin = linear_bin_indices[mel_bin - 1]
for f_bin in range(left_bin, center_freq_linear_bin+1):
if (center_freq_linear_bin - left_bin) > 0:
response = float(f_bin - left_bin)/float(center_freq_linear_bin - left_bin)
filterbank[mel_bin, f_bin] = response
# Create the right side of the triangular filter that ramps down
# from 1 to 0.
if center_freq_linear_bin < linear_bin_count-2:
# It is possible to create the right triangular filter.
if mel_bin == mel_bin_count - 1:
# Since this is the last mel bin, we must ramp down to response of 0
# at the last linear freq bin.
right_bin = min(linear_bin_count - 1, fft_bin_stop)
else:
right_bin = linear_bin_indices[mel_bin + 1]
for f_bin in range(center_freq_linear_bin, right_bin+1):
if (right_bin - center_freq_linear_bin) > 0:
response = float(right_bin - f_bin)/float(right_bin - center_freq_linear_bin)
filterbank[mel_bin, f_bin] = response
filterbank[mel_bin, center_freq_linear_bin] = 1.0
return filterbank
def stft_for_reconstruction(x, fft_size, hopsamp):
"""Compute and return the STFT of the supplied time domain signal x.
Args:
x (1-dim Numpy array): A time domain signal.
fft_size (int): FFT size. Should be a power of 2, otherwise DFT will be used.
hopsamp (int):
Returns:
The STFT. The rows are the time slices and columns are the frequency bins.
"""
window = np.hanning(fft_size)
fft_size = int(fft_size)
hopsamp = int(hopsamp)
return np.array([np.fft.rfft(window*x[i:i+fft_size])
for i in range(0, len(x)-fft_size, hopsamp)])
def istft_for_reconstruction(X, fft_size, hopsamp):
"""Invert a STFT into a time domain signal.
Args:
X (2-dim Numpy array): Input spectrogram. The rows are the time slices and columns are the frequency bins.
fft_size (int):
hopsamp (int): The hop size, in samples.
Returns:
The inverse STFT.
"""
fft_size = int(fft_size)
hopsamp = int(hopsamp)
window = np.hanning(fft_size)
time_slices = X.shape[0]
len_samples = int(time_slices*hopsamp + fft_size)
x = np.zeros(len_samples)
for n,i in enumerate(range(0, len(x)-fft_size, hopsamp)):
x[i:i+fft_size] += window*np.real(np.fft.irfft(X[n]))
return x
def get_signal(in_file, expected_fs=44100):
"""Load a wav file.
If the file contains more than one channel, return a mono file by taking
the mean of all channels.
If the sample rate differs from the expected sample rate (default is 44100 Hz),
raise an exception.
Args:
in_file: The input wav file, which should have a sample rate of `expected_fs`.
expected_fs (int): The expected sample rate of the input wav file.
Returns:
The audio siganl as a 1-dim Numpy array. The values will be in the range [-1.0, 1.0]. fixme ( not yet)
"""
fs, y = scipy.io.wavfile.read(in_file)
num_type = y[0].dtype
if num_type == 'int16':
y = y*(1.0/32768)
elif num_type == 'int32':
y = y*(1.0/2147483648)
elif num_type == 'float32':
# Nothing to do
pass
elif num_type == 'uint8':
raise Exception('8-bit PCM is not supported.')
else:
raise Exception('Unknown format.')
if fs != expected_fs:
raise Exception('Invalid sample rate.')
if y.ndim == 1:
return y
else:
return y.mean(axis=1)
def reconstruct_signal_griffin_lim(magnitude_spectrogram, fft_size, hopsamp, iterations):
"""Reconstruct an audio signal from a magnitude spectrogram.
Given a magnitude spectrogram as input, reconstruct
the audio signal and return it using the Griffin-Lim algorithm from the paper:
"Signal estimation from modified short-time fourier transform" by Griffin and Lim,
in IEEE transactions on Acoustics, Speech, and Signal Processing. Vol ASSP-32, No. 2, April 1984.
Args:
magnitude_spectrogram (2-dim Numpy array): The magnitude spectrogram. The rows correspond to the time slices
and the columns correspond to frequency bins.
fft_size (int): The FFT size, which should be a power of 2.
hopsamp (int): The hope size in samples.
iterations (int): Number of iterations for the Griffin-Lim algorithm. Typically a few hundred
is sufficient.
Returns:
The reconstructed time domain signal as a 1-dim Numpy array.
"""
time_slices = magnitude_spectrogram.shape[0]
len_samples = int(time_slices*hopsamp + fft_size)
# Initialize the reconstructed signal to noise.
x_reconstruct = np.random.randn(len_samples)
n = iterations # number of iterations of Griffin-Lim algorithm.
while n > 0:
n -= 1
reconstruction_spectrogram = stft_for_reconstruction(x_reconstruct, fft_size, hopsamp)
reconstruction_angle = np.angle(reconstruction_spectrogram)
# Discard magnitude part of the reconstruction and use the supplied magnitude spectrogram instead.
proposal_spectrogram = magnitude_spectrogram*np.exp(1.0j*reconstruction_angle)
prev_x = x_reconstruct
x_reconstruct = istft_for_reconstruction(proposal_spectrogram, fft_size, hopsamp)
diff = sqrt(sum((x_reconstruct - prev_x)**2)/x_reconstruct.size)
print('Reconstruction iteration: {}/{} RMSE: {} '.format(iterations - n, iterations, diff))
return x_reconstruct
def save_audio_to_file(x, sample_rate, outfile='out.wav'):
"""Save a mono signal to a file.
Args:
x (1-dim Numpy array): The audio signal to save. The signal values should be in the range [-1.0, 1.0].
sample_rate (int): The sample rate of the signal, in Hz.
outfile: Name of the file to save.
"""
x_max = np.max(abs(x))
assert x_max <= 1.0, 'Input audio value is out of range. Should be in the range [-1.0, 1.0].'
x = x*32767.0
data = array.array('h')
for i in range(len(x)):
cur_samp = int(round(x[i]))
data.append(cur_samp)
f = wave.open(outfile, 'w')
f.setparams((1, 2, sample_rate, 0, "NONE", "Uncompressed"))
f.writeframes(data.tostring())
f.close()
\ No newline at end of file
from utils import *
import pytorch_lightning as pl
from models.shared_stuff import *
from sklearn import metrics
# TODO pr-auc
......@@ -87,7 +85,7 @@ class CNN(pl.LightningModule):
def training_step(self, data_batch, batch_nb):
x, _, y = data_batch
y_hat = self.forward_full_song(x, y)
y_hat = self.forward(x)
y = y.float()
y_hat = y_hat.float()
return {'loss':self.my_loss(y_hat, y)}
......
from test_tube import HyperOptArgumentParser
from utils import *
from models.shared_stuff import BasePtlModel
class CRNN(BasePtlModel):
def __init__(self, config, num_class, hparams):
super(CRNN, self).__init__(config, hparams)
# init bn
self.bn_init = nn.BatchNorm2d(1)
# layer 1
self.conv_1 = nn.Conv2d(1, 64, 3, padding=1)
self.bn_1 = nn.BatchNorm2d(64)
self.mp_1 = nn.MaxPool2d((2, 4))
# layer 2
self.conv_2 = nn.Conv2d(64, 128, 3, padding=1)
self.bn_2 = nn.BatchNorm2d(128)
self.mp_2 = nn.MaxPool2d((2, 4))
# layer 3
self.conv_3 = nn.Conv2d(128, 128, 3, padding=1)
self.bn_3 = nn.BatchNorm2d(128)
self.mp_3 = nn.MaxPool2d((2, 4))
# layer 4
self.conv_4 = nn.Conv2d(128, 128, 3, padding=1)
self.bn_4 = nn.BatchNorm2d(128)
self.mp_4 = nn.MaxPool2d((3, 5))
# layer 5
self.conv_5 = nn.Conv2d(128, 64, 3, padding=1)
self.bn_5 = nn.BatchNorm2d(64)
self.mp_5 = nn.MaxPool2d((4, 4))
# recurrent layer
self.gru1 = nn.GRU(input_size=32,
hidden_size=hparams.gru_hidden_size,
num_layers=hparams.gru_num_layers)
# classifier
self.dense = nn.Linear(hparams.gru_hidden_size, num_class)
self.dropout = nn.Dropout(self.hparams.drop_prob)
def forward(self, x):
x = x[:, :, :, :512]
# init bn
x = self.bn_init(x)
# layer 1
x = self.mp_1(nn.ELU()(self.bn_1(self.conv_1(x))))
# layer 2
x = nn.ELU()(self.bn_2(self.conv_2(x)))
# x = self.mp_2(nn.ELU()(self.bn_2(self.conv_2(x))))
# layer 3
x = self.mp_3(nn.ELU()(self.bn_3(self.conv_3(x))))
# layer 4
x = self.mp_4(nn.ELU()(self.bn_4(self.conv_4(x))))
# layer 5
x = self.mp_5(nn.ELU()(self.bn_5(self.conv_5(x))))
# classifier
x = x.view(-1, x.size(0), 32)
x = self.gru1(x)[1][1] # TODO: Check if this is correct
x = self.dropout(x)
logit = nn.Sigmoid()(self.dense(x))
return logit
@staticmethod
def add_model_specific_args(parent_parser):
"""Parameters defined here will be available to your model through self.hparams
"""
parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
# network params
parser.add_argument('--gru_hidden_size', default=320, type=int)
parser.add_argument('--gru_num_layers', default=2, type=int)
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=True)
parser.opt_list('--learning_rate', default=0.0001, type=float,
options=[0.00001, 0.0005, 0.001],
tunable=True)
# training params (opt)
parser.opt_list('--optimizer_name', default='adam', type=str,
options=['adam'], tunable=False)
# if using 2 nodes with 4 gpus each the batch size here
# (256) will be 256 / (2*8) = 16 per gpu
parser.opt_list('--batch_size', default=32, type=int,
options=[32, 64, 128, 256], tunable=False,
help='batch size will be divided over all gpus being used across all nodes')
return parser
from datasets.midlevel import df_get_midlevel_set
from torch import optim
from torch.utils.data.dataset import random_split
from utils import PATH_ANNOTATIONS, PATH_AUDIO
import torch
import torch.nn.functional as F
......@@ -6,6 +9,8 @@ from sklearn import metrics
import os
from datasets.mtgjamendo import df_get_mtg_set
import numpy as np
import pytorch_lightning as pl
from datasets.shared_data_utils import *
def my_loss(y_hat, y):
return F.binary_cross_entropy(y_hat, y)
......@@ -141,4 +146,210 @@ def test_dataloader(batch_size=32):
dataset = df_get_mtg_set('mtgjamendo_test', test_csv, PATH_AUDIO, cache_x_name)
return DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True)
\ No newline at end of file