import os

import torch
import librosa
import numpy as np
import pandas as pd

from datasets.dataset import H5FCachedDataset, AudioPreprocessDataset
from utils import PATH_DATA_CACHE

slice_length = 512  # frames per training slice; df_get_mtg_set overwrites this global (TODO: find a better way)


def sample_slicing_function(h5data, idx, xlen):
    """Cut a random `slice_length`-frame window out of one song's spectrogram.

    Parameters
    ----------
    h5data : 2-D array-like of shape (total_frames, n_mels)
        Concatenated spectrogram frames of the whole dataset (H5 cache);
        n_mels is presumably 256, matching processor_mtgjamendo44k -- TODO confirm.
    idx : int
        Row at which this song's frames start inside h5data.
    xlen : int
        Number of frames belonging to this song.

    Returns
    -------
    torch.Tensor of shape (1, n_mels, slice_length).
    """
    timeframes = slice_length
    if xlen >= timeframes:
        # Uniform random start so every valid window is equally likely.
        k = torch.randint(xlen - timeframes + 1, (1,))[0].item()
        x = np.asarray(h5data[idx + k:idx + k + timeframes])
    else:
        # Song shorter than one slice: take everything and wrap-pad up to
        # length (previously this path crashed in torch.randint, high <= 0).
        x = np.asarray(h5data[idx:idx + xlen])
        x = np.pad(x, ((0, timeframes - xlen), (0, 0)), mode='wrap')
    n_bins = x.shape[1]  # generalized from the hard-coded 256
    return torch.from_numpy(x.transpose(1, 0).reshape(1, n_bins, timeframes))

17
18
19
20
21
22
23
24
25
26
27
28
29
30
def full_song_slicing_function(h5data, idx, xlen):
    """Return (up to) ``maxlen`` frames of one song as a (1, n_mels, maxlen) tensor.

    Long songs: a random contiguous maxlen-frame window is taken.
    Short songs: all frames, wrap-padded up to maxlen.

    NOTE(review): upstream marked this "not working"; the leftover debug
    prints were removed -- verify against H5FCachedDataset before relying on it.
    """
    maxlen = 2048
    if xlen > maxlen:
        # Uniform random window start within [0, xlen - maxlen].
        k = torch.randint(xlen - maxlen + 1, (1,))[0].item()
        x = np.asarray(h5data[idx + k:idx + k + maxlen])
    else:
        x = np.asarray(h5data[idx:idx + xlen])
        x = np.pad(x, ((0, maxlen - xlen), (0, 0)), mode='wrap')
    n_bins = x.shape[1]  # generalized from the hard-coded 256
    return torch.from_numpy(x.transpose((1, 0)).reshape((1, n_bins, -1)))

# Memoizes t2_parse_labels results per csv path (parsing + binarizing is slow).
t2_parse_labels_cache = {}


def t2_parse_labels(csvf):
    """Parse an MTG-Jamendo tab-separated metadata file.

    Parameters
    ----------
    csvf : str
        Path to a tab-separated file with at least 'PATH' and 'TAGS' columns;
        'TAGS' holds comma-separated tag lists.

    Returns
    -------
    (files, bins, mlb):
        files -- array of audio paths (the 'PATH' column),
        bins  -- binary indicator matrix over the tag vocabulary,
        mlb   -- the fitted sklearn MultiLabelBinarizer.
    Results are memoized in t2_parse_labels_cache keyed by the file path.
    """
    try:
        # Single dict lookup instead of two .get() calls; no `global` needed,
        # since the cache dict is only mutated, never rebound.
        return t2_parse_labels_cache[csvf]
    except KeyError:
        pass
    df = pd.read_csv(csvf, sep='\t')
    files = df['PATH'].values
    labels = [set(tags.split(",")) for tags in df['TAGS'].values]
    # Imported lazily so merely importing this module does not require sklearn.
    from sklearn.preprocessing import MultiLabelBinarizer
    mlb = MultiLabelBinarizer()
    bins = mlb.fit_transform(labels)
    t2_parse_labels_cache[csvf] = files, bins, mlb
    return t2_parse_labels_cache[csvf]


def processor_mtgjamendo44k(file_path):
    """Load an audio file and compute a perceptually weighted mel spectrogram.

    Parameters
    ----------
    file_path : str
        Path to the audio file to decode.

    Returns
    -------
    torch.Tensor (float32) of shape (channels, n_mels, frames); with
    mono=True, channels is always 1.
    """
    n_fft = 2048
    sr = 44100        # resample everything to 44.1 kHz
    mono = True       # @todo ask mattias
    log_spec = False  # False -> perceptual (dB) weighting instead of log10
    n_mels = 256
    hop_length = 512
    fmax = None       # use the mel bank's full frequency range

    if mono:
        # Resampling inside librosa.load is the slowest part of the pipeline.
        sig, sr = librosa.load(file_path, sr=sr, mono=True)
        sig = sig[np.newaxis]  # (1, samples) so the channel loop below is uniform
    else:
        sig, sr = librosa.load(file_path, sr=sr, mono=False)

    spectrograms = []
    for y in sig:
        # compute STFT
        stft = librosa.stft(y, n_fft=n_fft, hop_length=hop_length,
                            win_length=None, window='hann', center=True,
                            pad_mode='reflect')

        # keep only amplitudes
        stft = np.abs(stft)

        # spectrogram weighting
        if log_spec:
            stft = np.log10(stft + 1)
        else:
            freqs = librosa.core.fft_frequencies(sr=sr, n_fft=n_fft)
            stft = librosa.perceptual_weighting(stft ** 2, freqs, ref=1.0,
                                                amin=1e-10, top_db=80.0)

        # apply mel filterbank
        spectrogram = librosa.feature.melspectrogram(S=stft, sr=sr,
                                                     n_mels=n_mels, fmax=fmax)

        spectrograms.append(np.asarray(spectrogram))

    spectrograms = np.asarray(spectrograms, dtype=np.float32)

    return torch.from_numpy(spectrograms)


# Defaults consumed by df_get_mtg_set below; label_encoder is filled in
# lazily with the fitted MultiLabelBinarizer.
audio_processor = processor_mtgjamendo44k
label_encoder = None


def df_get_mtg_set(name, mtg_files_csv, audio_path, cache_x_name, slicing_func=None, slice_len=512, augment_options=None):
    """Build an H5-cached MTG-Jamendo dataset.

    Parameters
    ----------
    name : str
        Dataset name, used for the cache and logging.
    mtg_files_csv : str
        Tab-separated metadata file parsed by t2_parse_labels.
    audio_path : str
        Root directory of the audio files (~ is expanded).
    cache_x_name : str
        Name for the cached feature set inside PATH_DATA_CACHE.
    slicing_func : callable or None
        (h5data, idx, xlen) -> tensor; defaults to sample_slicing_function.
    slice_len : int
        Slice length installed into the module-level slice_length global,
        which sample_slicing_function reads.
    augment_options : passed through to H5FCachedDataset.

    Returns
    -------
    H5FCachedDataset wrapping an AudioPreprocessDataset.
    """
    global slice_length
    slice_length = slice_len  # sample_slicing_function reads this global

    audio_path = os.path.expanduser(audio_path)
    print("loading dataset from '{}'".format(name))

    def getdatset():
        # Publish the fitted binarizer at module level: the original declared
        # `global label_encoder` in the outer scope but never assigned it
        # there, so the fitted encoder was silently dropped.
        global label_encoder
        files, labels, label_encoder = t2_parse_labels(mtg_files_csv)
        return AudioPreprocessDataset(files, labels, label_encoder, audio_path, audio_processor)

    if slicing_func is None:
        slicing_func = sample_slicing_function

    return H5FCachedDataset(getdatset, name,
                            slicing_function=slicing_func,
                            x_name=cache_x_name,
                            cache_path=PATH_DATA_CACHE,
                            augment_options=augment_options)
