import os import time import hashlib import datetime import torch import torch.nn as nn import logging import numpy as np import pandas as pd import pytorch_lightning as ptl from matplotlib import pyplot as plt from tqdm import tqdm plt.rcParams["figure.dpi"] = 288 # increase dpi for clearer plots from plotting import * # mostly for debug # PARAMS ======================= INPUT_SIZE = (96, 256) MAX_FRAMES = 40 # CONFIG ======================= # paths: hostname = os.uname()[1] if hostname in ['rechenknecht0.cp.jku.at', 'rechenknecht1.cp.jku.at', 'rechenknecht3.cp.jku.at']: plt.switch_backend('agg') PATH_DATA_ROOT = '/home/shreyan/data/MTG-Jamendo' USE_GPU = True else: PATH_DATA_ROOT = '/mnt/2tb/datasets/MTG-Jamendo' USE_GPU = False PATH_PROJECT_ROOT = os.path.dirname(os.path.realpath(__file__)) PATH_AUDIO = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_audio') PATH_ANNOTATIONS = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_annotations') PATH_MELSPEC_DOWNLOADED = os.path.join(PATH_DATA_ROOT, 'MTG-Jamendo_melspec_downloaded') PATH_MELSPEC_DOWNLOADED_FRAMED = os.path.join(PATH_MELSPEC_DOWNLOADED, 'framed') PATH_RESULTS = os.path.join(PATH_PROJECT_ROOT, 'results') TRAINED_MODELS_PATH = '' # run name def make_run_name(suffix=''): assert ' ' not in suffix hash = hashlib.sha1() hash.update(str(time.time()).encode('utf-8')) run_hash = hash.hexdigest()[:5] name = run_hash + suffix return name curr_run_name = make_run_name() CURR_RUN_PATH = os.path.join(PATH_RESULTS, 'runs', curr_run_name) if not os.path.isdir(CURR_RUN_PATH): os.mkdir(CURR_RUN_PATH) # SET UP LOGGING ============================================= filelog = logging.getLogger() streamlog = logging.getLogger() logger = logging.getLogger() fh = logging.FileHandler(os.path.join(CURR_RUN_PATH, f'{curr_run_name}.log')) sh = logging.StreamHandler() formatter = logging.Formatter('%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s') fh.setFormatter(formatter) sh.setFormatter(formatter) # filelog logs only to file filelog.addHandler(fh) filelog.setLevel(logging.INFO) # streamlog logs only to terminal streamlog.addHandler(sh) streamlog.setLevel(logging.INFO) # logger logs to both file and terminal logger.addHandler(fh) logger.addHandler(sh) logger.setLevel(logging.DEBUG) # ============================================ def write_to_file(data, path): # not fully implemented. unused function as of now. with open(path, 'w') as f: if isinstance(data, np.ndarray): for i in data: f.writelines(i) def dims_calc(obj, in_shape): """ utility function to calculate output dimensions of a conv2d or maxpool2d stage """ kernel_size = obj.kernel_size stride = obj.stride padding = obj.padding dilation = obj.dilation h_in = in_shape[0] w_in = in_shape[1] if isinstance(obj, nn.Conv2d): h_out = int(((h_in + 2*padding[0] - dilation[0]*(kernel_size[0]-1))/stride[0])+1) w_out = int(((w_in + 2*padding[1] - dilation[1]*(kernel_size[1]-1))/stride[1])+1) out_shape = [h_out, w_out, obj.out_channels] elif isinstance(obj, nn.MaxPool2d): if isinstance(padding, int): padding = (padding, padding) if isinstance(dilation, int): dilation = (dilation, dilation) if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) if isinstance(stride, int): stride = (stride, stride) h_out = int(((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) -1) / stride[0]) + 1) w_out = int(((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1)-1) / stride[1]) + 1) out_shape = [h_out, w_out, in_shape[2]] else: out_shape = [None, None, None] return out_shape def preprocess_and_save_annotation_files(): """ Removes 'mood/theme---' from tag names, and replaces tabs between multiple tag names with commas. Writes processed filename.ext as filename_processed.ext """ import re filelist = os.listdir(PATH_ANNOTATIONS) for file in filelist: # Check if the current file is processed or has a processed copy. # Skip if either of this is true. Else process. if 'processed' in os.path.splitext(file)[0].split('_') or\ f'{os.path.splitext(file)[0]}_processed{os.path.splitext(file)[1]}' in filelist: continue else: with open(os.path.join(PATH_ANNOTATIONS, file), 'r') as f: text = f.read() text = re.sub(r'mood/theme---(\w*)\n', r'\1\n', text) # matches last or singular tags text = re.sub(r'mood/theme---(\w*)(\s*)', r'\1,', text) # matches all other tags with open(os.path.join(PATH_ANNOTATIONS, f'{os.path.splitext(file)[0]}_processed{os.path.splitext(file)[1]}'), 'w') as fw: fw.write(text) def trim_silence(spec, thresh=0.1): """ Trims silence from the beginning and end of a song spectrogram based on a threshold applied to the median loudness. Loudness is calculated by summing the magnitudes over the frequency axis for each time frame. """ loudness = np.sum(spec, axis=0) loudness = loudness - np.min(loudness) cutoff = thresh*np.median(loudness) start = 0 end = len(loudness) for i in range(len(loudness)): if loudness[i] > cutoff: start = i break for i in range(len(loudness)-1, start, -1): if loudness[i] > cutoff: end = i break return spec[:,start:end] def make_framed_spec(spec, frame_length, total_frames=None, hop=0.5, discard_end=False, filler='wrap'): """ Given a spectrogram of an entire song, this function splits it into frames and returns a torch tensor with an additional dimension (frame number) and the framed spectrogram chunks. Each frame is meant to be directly fed into a model input of matching size. """ assert filler in ['wrap', 'pad'], logger.error(f"filler is {filler}, must be either wrap or pad") fstart = 0 fend = int(frame_length) framed_spec = [] while fend < spec.shape[1]: framed_spec.append(spec[:,fstart:fend]) fstart += int(hop*frame_length) fend = fstart + int(frame_length) if not discard_end: framed_spec.append(spec[:,-frame_length:]) if total_frames is not None: if len(framed_spec) > total_frames: framed_spec = framed_spec[:total_frames] else: if filler in ['wrap']: # Wrap around while len(framed_spec) < total_frames: framed_spec.extend(framed_spec[0:total_frames-len(framed_spec)]) else: # Pad with silence silence = np.zeros(spec[:,0:frame_length].shape) while len(framed_spec) < total_frames: framed_spec.extend(silence) framed_spec = torch.from_numpy(np.array(framed_spec)) return framed_spec def preprocess_specs(source_root, destination_root, frame_length=256, hop=1.0): """ Reads spectrograms from source_root and performs: - trim_silence() - make_framed_spec() and saves the resulting framed spectrograms to destination_root """ if not os.path.exists(destination_root): os.mkdir(destination_root) filelist = os.walk(source_root) for dirpath, _, filenames in filelist: # Ignore dir of framed specs if dirpath is destination_root: continue for filename in tqdm(filenames): destination_subdir = os.path.join(destination_root, dirpath.split('/')[-1]) if os.path.exists(os.path.join(destination_subdir, filename)): # If framed melspec already exists, don't preprocess continue else: if not os.path.exists(destination_subdir): os.mkdir(destination_subdir) destination = os.path.join(destination_subdir, filename) spec = np.load(os.path.join(dirpath, filename)) spec = trim_silence(spec) framed_spec = make_framed_spec(spec, frame_length=frame_length, hop=hop) np.save(destination, framed_spec) if __name__=='__main__': # TESTS # c = nn.Conv2d(512, 256, 1, 1, 0) # (in_channels, out_channels, kernel_size, stride, padding) # m = nn.MaxPool2d(2) # print(dims_calc(c, [37, 17, 512])) preprocess_specs(source_root=PATH_MELSPEC_DOWNLOADED, destination_root=PATH_MELSPEC_DOWNLOADED_FRAMED) pass