Commit 9f0e4820 authored by Shreyan Chowdhury
Browse files

add project files

parent e3345094
*.pyc
*.idea*
# Conda environment specification for the environment named `midlevel_da`.
name: midlevel_da
# Channels listed for package lookup.
channels:
- conda-forge
- defaults
# Pinned packages, each written as `name=version=build`.
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=1_llvm
- appdirs=1.4.3=py_1
- argon2-cffi=20.1.0=py38h1e0a361_1
- async_generator=1.10=py_0
- attrs=20.2.0=pyh9f0ad1d_0
- audioread=2.1.8=py38h32f6830_2
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=py_2
- backports.functools_lru_cache=1.6.1=py_0
- blas=1.0=mkl
- bleach=3.2.1=pyh9f0ad1d_0
- brotlipy=0.7.0=py38h1e0a361_1000
- bzip2=1.0.8=h516909a_2
- ca-certificates=2020.6.20=hecda079_0
- certifi=2020.6.20=py38h32f6830_0
- cffi=1.14.1=py38h5bae8af_0
- chardet=3.0.4=py38h32f6830_1006
- cryptography=3.0=py38h766eaa4_0
- cudatoolkit=10.2.89=hfd86e86_1
- cycler=0.10.0=py_2
- cython=0.29.21=py38h950e882_0
- dbus=1.13.6=he372182_0
- decorator=4.4.2=py_0
- defusedxml=0.6.0=py_0
- entrypoints=0.3=py38h32f6830_1001
- expat=2.2.9=he1b5a44_2
- ffmpeg=4.3.1=h167e202_0
- fontconfig=2.13.1=h1056068_1002
- freetype=2.10.2=he06d7ca_0
- gettext=0.19.8.1=hc5be6a0_1002
- glib=2.58.3=py38h73cb85d_1004
- gmp=6.2.0=he1b5a44_2
- gnutls=3.6.13=h79a8f9a_0
- gst-plugins-base=1.14.5=h0935bb2_2
- gstreamer=1.14.5=h36ae1b5_2
- icu=67.1=he1b5a44_0
- importlib-metadata=2.0.0=py38h32f6830_0
- importlib_metadata=2.0.0=0
- intel-openmp=2020.1=217
- ipykernel=5.3.4=py38h23f93f0_0
- ipython=7.17.0=py38h1cdfbd6_0
- ipython_genutils=0.2.0=py_1
- ipywidgets=7.5.1=pyh9f0ad1d_1
- jedi=0.17.2=py38h32f6830_0
- jinja2=2.11.2=pyh9f0ad1d_0
- joblib=0.16.0=py_0
- jpeg=9d=h516909a_0
- jsonschema=3.2.0=py38h32f6830_1
- jupyter=1.0.0=py_2
- jupyter_client=6.1.6=py_0
- jupyter_console=6.2.0=py_0
- jupyter_core=4.6.3=py38h32f6830_1
- jupyterlab_pygments=0.1.1=pyh9f0ad1d_0
- kiwisolver=1.2.0=py38hbf85e49_0
- krb5=1.17.1=hfafb76e_3
- lame=3.100=h14c3975_1001
- lcms2=2.11=hbd6801e_0
- ld_impl_linux-64=2.34=hc38a660_9
- libblas=3.8.0=16_mkl
- libcblas=3.8.0=16_mkl
- libclang=10.0.1=default_hde54327_1
- libedit=3.1.20191231=h46ee950_1
- libevent=2.1.10=hcdb4288_2
- libffi=3.2.1=he1b5a44_1007
- libflac=1.3.3=he1b5a44_0
- libgcc-ng=9.3.0=h24d8f2e_14
- libgfortran-ng=7.5.0=hdf63c60_14
- libiconv=1.15=h516909a_1006
- liblapack=3.8.0=16_mkl
- libllvm10=10.0.1=he513fc3_3
- libllvm9=9.0.1=he513fc3_1
- libogg=1.3.2=h516909a_1002
- libpng=1.6.37=hed695b0_1
- libpq=12.3=h5513abc_0
- librosa=0.8.0=pyh9f0ad1d_0
- libsndfile=1.0.28=he1b5a44_1000
- libsodium=1.0.18=h516909a_0
- libstdcxx-ng=9.3.0=hdf63c60_14
- libtiff=4.1.0=hc7e4089_6
- libuuid=2.32.1=h14c3975_1000
- libvorbis=1.3.7=he1b5a44_0
- libwebp-base=1.1.0=h516909a_3
- libxcb=1.13=h14c3975_1002
- libxkbcommon=0.10.0=he1b5a44_0
- libxml2=2.9.10=h72b56ed_2
- llvm-openmp=10.0.1=hc9558a2_0
- llvmlite=0.33.0=py38h4f45e52_1
- lz4-c=1.9.2=he1b5a44_1
- markupsafe=1.1.1=py38h1e0a361_1
- matplotlib-base=3.3.0=py38h91b0d89_1
- mistune=0.8.4=py38h1e0a361_1001
- mkl=2020.2=256
- mkl-service=2.3.0=py38he904b0f_0
- mkl_fft=1.1.0=py38h23d657b_0
- mkl_random=1.1.1=py38hcb8c335_0
- mysql-common=8.0.21=2
- mysql-libs=8.0.21=hf3661c5_2
- nbclient=0.5.0=py_0
- nbconvert=6.0.6=py38h32f6830_0
- nbformat=5.0.7=py_0
- ncurses=6.2=he1b5a44_1
- nest-asyncio=1.4.0=py_1
- nettle=3.4.1=h1bed415_1002
- ninja=1.10.0=hc9558a2_0
- notebook=6.1.4=py38h32f6830_0
- nspr=4.29=he1b5a44_0
- nss=3.57=he751ad9_0
- numba=0.50.1=py38hcb8c335_1
- numpy=1.19.1=py38h8854b6b_0
- numpy-base=1.18.5=py38hde5b4d6_0
- olefile=0.46=py_0
- openh264=2.1.1=h8b12597_0
- openssl=1.1.1h=h516909a_0
- packaging=20.4=pyh9f0ad1d_0
- pandas=1.1.0=py38h950e882_0
- pandoc=2.10.1=h516909a_0
- pandocfilters=1.4.2=py_1
- parso=0.7.1=pyh9f0ad1d_0
- patsy=0.5.1=py_0
- pcre=8.44=he1b5a44_0
- pexpect=4.8.0=py38h32f6830_1
- pickleshare=0.7.5=py38h32f6830_1001
- pillow=7.2.0=py38h9776b28_1
- pip=20.2.1=py_0
- pooch=1.1.1=py_0
- prometheus_client=0.8.0=pyh9f0ad1d_0
- prompt-toolkit=3.0.7=py_0
- prompt_toolkit=3.0.7=0
- pthread-stubs=0.4=h14c3975_1001
- ptyprocess=0.6.0=py_1001
- pycparser=2.20=pyh9f0ad1d_2
- pygments=2.6.1=py_0
- pyopenssl=19.1.0=py_1
- pyparsing=2.4.7=pyh9f0ad1d_0
- pyqt=5.12.3=py38ha8c2ead_3
- pyrsistent=0.17.3=py38h1e0a361_0
- pysocks=1.7.1=py38h32f6830_1
- pysoundfile=0.10.2=py_1001
- python=3.8.5=h4d41432_2_cpython
- python-dateutil=2.8.1=py_0
- python_abi=3.8=1_cp38
- pytz=2020.1=pyh9f0ad1d_0
- pyzmq=19.0.2=py38ha71036d_0
- qt=5.12.9=h1f2b2cb_0
- qtconsole=4.7.7=pyh9f0ad1d_0
- qtpy=1.9.0=py_0
- readline=8.0=he28a2e2_2
- resampy=0.2.2=py_0
- scikit-learn=0.23.2=py38hee58b96_0
- scipy=1.5.2=py38h8c5af15_0
- seaborn=0.10.1=1
- seaborn-base=0.10.1=py_1
- send2trash=1.5.0=py_0
- setuptools=49.3.1=py38h32f6830_0
- six=1.15.0=pyh9f0ad1d_0
- sqlite=3.33.0=h4cf870e_0
- statsmodels=0.11.1=py38h1e0a361_2
- tbb=2020.1=hc9558a2_0
- terminado=0.9.1=py38h32f6830_0
- testpath=0.4.4=py_0
- threadpoolctl=2.1.0=pyh5ca1d4c_0
- tk=8.6.10=hed695b0_0
- tornado=6.0.4=py38h1e0a361_1
- tqdm=4.48.2=pyh9f0ad1d_0
- traitlets=4.3.3=py38h32f6830_1
- wcwidth=0.2.5=pyh9f0ad1d_1
- webencodings=0.5.1=py_1
- wheel=0.34.2=py_1
- widgetsnbextension=3.5.1=py38h32f6830_1
- x264=1!152.20180806=h14c3975_0
- xorg-libxau=1.0.9=h14c3975_0
- xorg-libxdmcp=1.1.3=h516909a_0
- xz=5.2.5=h516909a_1
- zeromq=4.3.2=he1b5a44_3
- zipp=3.2.0=py_0
- zlib=1.2.11=h516909a_1007
- zstd=1.4.5=h6597ccf_2
# Packages installed through pip; they have to be nested under the `pip` key,
# otherwise `pip` is null and the entries below are read as conda packages.
- pip:
  - absl-py==0.9.0
  - cachetools==4.1.0
  - future==0.18.2
  - google-auth==1.16.0
  - google-auth-oauthlib==0.4.1
  - grpcio==1.29.0
  - idna==2.9
  - madmom==0.16.1
  - markdown==3.2.2
  - mido==1.2.9
  - oauthlib==3.1.0
  - protobuf==3.12.2
  - pyasn1==0.4.8
  - pyasn1-modules==0.2.8
  - pyqt5-sip==4.19.18
  - pyqtchart==5.12
  - pyqtwebengine==5.12.1
  - requests==2.23.0
  - requests-oauthlib==1.3.0
  - rsa==4.0
  - tensorboard==2.2.2
  - tensorboard-plugin-wit==1.6.0.post3
  - torch==1.6.0
  # NOTE(review): torchaudio 0.5.x appears to target torch 1.5.x; confirm that
  # pairing it with torch 1.6.0 is intended.
  - torchaudio==0.5.1
  - torchcontrib==0.0.2
  - torchsummary==1.5.1
  - torchvision==0.7.0
  - urllib3==1.25.9
  - werkzeug==1.0.1
# Conda environment specification for the environment named `midlevel_da`
# (this variant additionally lists the `pytorch` channel).
name: midlevel_da
# Channels listed for package lookup.
channels:
- pytorch
- conda-forge
- defaults
# Pinned packages, each written as `name=version=build`.
dependencies:
- _libgcc_mutex=0.1=main
- audioread=2.1.8=py38h32f6830_2
- blas=1.0=mkl
- bzip2=1.0.8=h516909a_2
- ca-certificates=2020.1.1=0
- certifi=2020.4.5.1=py38_0
- cffi=1.14.0=py38he30daa8_1
- cudatoolkit=10.2.89=hfd86e86_1
- cycler=0.10.0=py_2
- cython=0.29.17=py38he6710b0_0
- dbus=1.13.14=hb2f20db_0
- decorator=4.4.2=py_0
- expat=2.2.6=he6710b0_0
- ffmpeg=4.2.3=h167e202_0
- fontconfig=2.13.0=h9420a91_0
- freetype=2.9.1=h8a8886c_1
- gettext=0.19.8.1=h5e8e0c9_1
- glib=2.63.1=h3eb4bd4_1
- gmp=6.2.0=he1b5a44_2
- gnutls=3.6.13=h79a8f9a_0
- gst-plugins-base=1.14.0=hbbd80ab_1
- gstreamer=1.14.0=hb31296c_0
- icu=58.2=hf484d3e_1000
- intel-openmp=2020.1=217
- joblib=0.15.1=py_0
- jpeg=9b=h024ee3a_2
- kiwisolver=1.2.0=py38hbf85e49_0
- lame=3.100=h14c3975_1001
- ld_impl_linux-64=2.33.1=h53a641e_7
- libedit=3.1.20181209=hc058e9b_0
- libffi=3.3=he6710b0_1
- libflac=1.3.3=he1b5a44_0
- libgcc-ng=9.1.0=hdf63c60_0
- libgfortran-ng=7.3.0=hdf63c60_0
- libiconv=1.15=h516909a_1006
- libogg=1.3.2=h516909a_1002
- libpng=1.6.37=hbc83047_0
- librosa=0.7.2=py_1
- libsndfile=1.0.28=he1b5a44_1000
- libstdcxx-ng=9.1.0=hdf63c60_0
- libtiff=4.1.0=h2733197_0
- libuuid=1.0.3=h1bed415_2
- libvorbis=1.3.6=he1b5a44_2
- libxcb=1.13=h1bed415_1
- libxml2=2.9.9=hea5a465_1
- llvmlite=0.32.1=py38hd408876_0
- matplotlib=3.1.3=py38_0
- matplotlib-base=3.1.3=py38hef1b27d_0
- mkl=2020.1=217
- mkl-service=2.3.0=py38he904b0f_0
- mkl_fft=1.0.15=py38ha843d7b_0
- mkl_random=1.1.1=py38h0573a6f_0
- ncurses=6.2=he6710b0_1
- nettle=3.4.1=h1bed415_1002
- ninja=1.9.0=py38hfd86e86_0
- numba=0.49.1=py38h0573a6f_0
- numpy=1.18.1=py38h4f9e942_0
- numpy-base=1.18.1=py38hde5b4d6_1
- olefile=0.46=py_0
- openh264=2.1.1=h8b12597_0
- openssl=1.1.1g=h7b6447c_0
- pandas=1.0.3=py38h0573a6f_0
- pcre=8.43=he6710b0_0
- pillow=7.1.2=py38hb39fc2d_0
- pip=20.0.2=py38_3
- pycparser=2.20=py_0
- pyparsing=2.4.7=pyh9f0ad1d_0
- pyqt=5.9.2=py38h05f1152_4
- pysoundfile=0.10.2=py_1001
- python=3.8.3=hcff3b4d_0
- python-dateutil=2.8.1=py_0
- python_abi=3.8=1_cp38
- pytz=2020.1=py_0
- qt=5.9.7=h5867ecd_1
- readline=8.0=h7b6447c_0
- resampy=0.2.2=py_0
- scikit-learn=0.22.1=py38hd81dba3_0
- scipy=1.4.1=py38h0b6359f_0
- seaborn=0.10.1=py_0
- setuptools=46.4.0=py38_0
- sip=4.19.13=py38he6710b0_0
- six=1.14.0=py38_0
- sqlite=3.31.1=h62c20be_1
- tbb=2020.1=hc9558a2_0
- tk=8.6.8=hbc83047_0
- torchvision=0.6.0=py38_cu102
- tornado=6.0.4=py38h1e0a361_1
- tqdm=4.46.0=py_0
- wheel=0.34.2=py38_0
- x264=1!152.20180806=h14c3975_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.3.7=h0b5b093_0
# Packages installed through pip; they have to be nested under the `pip` key,
# otherwise `pip` is null and the entries below are read as conda packages.
- pip:
  - absl-py==0.9.0
  - blessings==1.7
  - cachetools==4.1.0
  - chardet==3.0.4
  - future==0.18.2
  - google-auth==1.16.0
  - google-auth-oauthlib==0.4.1
  - gpustat==0.6.0
  - grpcio==1.29.0
  - idna==2.9
  - madmom==0.16.1
  - markdown==3.2.2
  - mido==1.2.9
  - nvidia-ml-py3==7.352.0
  - oauthlib==3.1.0
  - protobuf==3.12.2
  - psutil==5.7.0
  - pyasn1==0.4.8
  - pyasn1-modules==0.2.8
  - requests==2.23.0
  - requests-oauthlib==1.3.0
  - rsa==4.0
  - tensorboard==2.2.2
  - tensorboard-plugin-wit==1.6.0.post3
  - torch==1.5.1
  - torchaudio==0.5.1
  - torchcontrib==0.0.2
  - urllib3==1.25.9
  - werkzeug==1.0.1
import numpy as np
import torch
import torch.utils.data
def flip_ver(x):
    """Return ``x`` with the order of its entries along dimension 1 reversed."""
    return torch.flip(x, dims=(1,))
def flip_hor(x):
    """Return ``x`` with the order of its entries along dimension 2 reversed."""
    return torch.flip(x, dims=(2,))
class DsetSSFlipRand(torch.utils.data.Dataset):
    """Wrap a dataset and randomly flip each sample along dimension 1.

    Items are ``(path, x, label)``, where ``path`` and ``x`` are the first two
    elements of the wrapped item and ``label`` is 1 when ``x`` was flipped and
    0 when it was returned untouched.
    """

    def __init__(self, dset):
        # Wrapped dataset; its items must provide at least (path, x).
        self.dset = dset

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, index):
        path, x = self.dset[index][:2]
        # Draw 0 or 1; the draw doubles as the returned label.
        label = np.random.randint(2)
        if label != 0:
            x = x.flip(1)
        return path, x, label
class DsetSSMixup(torch.utils.data.Dataset):
    """Wrap a dataset and randomly splice each sample with another sample.

    Items are ``(path, x, label)``. With label 0 the sample ``x`` is returned
    unchanged; with label 1 the part of ``x`` after the midpoint of dimension 1
    is replaced by the corresponding part of a different, randomly chosen
    sample of the wrapped dataset.
    """

    def __init__(self, dset):
        # Wrapped dataset; its items must provide at least (path, x).
        self.dset = dset

    def __getitem__(self, index):
        path, x = self.dset[index][:2]
        n_items = len(self.dset)
        rand_idx = np.random.randint(n_items)
        if rand_idx == index:
            # Step to the neighbouring sample, wrapping around so that a
            # collision on the last index does not run past the dataset end.
            rand_idx = (rand_idx + 1) % n_items
        _, x1 = self.dset[rand_idx][:2]
        label = np.random.randint(2)
        if label != 0:
            split_pt = x.shape[1] // 2  # + np.random.randint(-100,100)
            x = torch.cat([x[:, :split_pt], x1[:, split_pt:]], 1)
        return path, x, label

    def __len__(self):
        return len(self.dset)
class DsetSSTimeSwap(torch.utils.data.Dataset):
    """Wrap a dataset and randomly swap the two halves of each sample.

    Items are ``(path, x, label)``. With label 1 the two halves of ``x`` along
    dimension 1 (split at ``x.shape[1] // 2``) are exchanged; with label 0
    ``x`` is returned unchanged.
    """

    def __init__(self, dset):
        # Wrapped dataset; its items must provide at least (path, x).
        self.dset = dset

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, index):
        path, x = self.dset[index][:2]
        label = np.random.randint(2)
        if label != 0:
            half = x.shape[1] // 2
            head, tail = x[:, :half], x[:, half:]
            # Second half first, then the first half.
            x = torch.cat([tail, head], 1)
        return path, x, label
def tensor_rot_90(x):
    """Reverse ``x`` along dimension 2, then swap dimensions 1 and 2."""
    mirrored = torch.flip(x, dims=(2,))
    return torch.transpose(mirrored, 1, 2)
def tensor_rot_90_digit(x):
    """Swap dimensions 1 and 2 of ``x`` (no flip is applied)."""
    return torch.transpose(x, 1, 2)
def tensor_rot_180(x):
    """Reverse ``x`` along both dimension 2 and dimension 1."""
    return torch.flip(x, dims=(2, 1))
def tensor_rot_180_digit(x):
    """Reverse ``x`` along dimension 2 only."""
    return torch.flip(x, dims=(2,))
def tensor_rot_270(x):
    """Swap dimensions 1 and 2 of ``x``, then reverse along dimension 2."""
    swapped = torch.transpose(x, 1, 2)
    return torch.flip(swapped, dims=(2,))
class DsetSSRotRand(torch.utils.data.Dataset):
    """Wrap a dataset and apply one of four randomly chosen transforms.

    Items are ``(image, label)`` with ``label`` drawn from 0..3:
    0 leaves the image unchanged, 1 and 2 apply the 90 / 180 helpers (the
    ``*_digit`` variants when ``digit`` is true), and 3 always applies
    ``tensor_rot_270``. Only element 0 of the wrapped item is used.
    """

    def __init__(self, dset, digit=True):
        self.dset = dset
        # Selects the ``*_digit`` helper variants for labels 1 and 2.
        self.digit = digit

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, index):
        image = self.dset[index][0]
        label = np.random.randint(4)
        if label == 1:
            rotate = tensor_rot_90_digit if self.digit else tensor_rot_90
            image = rotate(image)
        elif label == 2:
            rotate = tensor_rot_180_digit if self.digit else tensor_rot_180
            image = rotate(image)
        elif label == 3:
            # No digit-specific variant exists for this label.
            image = tensor_rot_270(image)
        return image, label
import hashlib
import os
from tqdm import tqdm
from helpers.audio import MadmomAudioProcessor
from paths import *
from utils import *
import torch
import torch.utils.data
import numpy as np
from sklearn.model_selection import train_test_split
# In-memory cache of already resolved statistics, keyed by dataset name; each
# value is a dict with the keys 'mean' and 'std'.
dset_stats = {}


def get_dataset_stats(dset_name, aud_processor=None):
    """Return ``(mean, std)`` of the cached spectrograms of ``dset_name``.

    Lookup order:
      1. the in-memory ``dset_stats`` cache,
      2. ``mean.npy`` / ``std.npy`` inside the spectrogram cache directory,
      3. a pass over every ``.specobj`` file in that directory; the result is
         saved next to the spectrograms for later calls.

    The mean is the average of per-file means and the std is the square root
    of the average per-file mean squared deviation, so every file contributes
    equally regardless of its size.

    Args:
        dset_name: name of the dataset; also its cache sub-directory.
        aud_processor: processor whose ``get_params.get("name")`` names the
            spectrogram sub-directory. Defaults to
            ``MadmomAudioProcessor(fps=31.3)``.

    Returns:
        Tuple ``(mean, std)``. Values read from ``.npy`` files are whatever
        ``np.load`` returns; freshly computed values are scalars.

    Raises:
        Exception: if the cached entry lacks 'mean' or 'std', or if a
            statistic has to be computed and no ``.specobj`` file exists.
    """
    # Answer from memory first so that a cache hit neither builds a processor
    # nor touches the file system.
    # NOTE(review): the cache is keyed by dataset name only; stats of a second
    # processor for the same dataset would be served from the first entry.
    cached = dset_stats.get(dset_name)
    if cached is not None:
        try:
            return cached['mean'], cached['std']
        except KeyError as err:
            raise Exception(f"mean or std not found for {dset_name}") from err

    if aud_processor is None:
        aud_processor = MadmomAudioProcessor(fps=31.3)
    proc_name = aud_processor.get_params.get("name")
    all_specs_path = os.path.join(path_cache_fs, dset_name, proc_name)

    def list_spec_files():
        # Fail with a readable message instead of a ZeroDivisionError below.
        files = list_files_deep(all_specs_path, full_paths=True, filter_ext=['.specobj'])
        if len(files) == 0:
            raise Exception(f"no .specobj files found in {all_specs_path}, cannot compute stats for {dset_name}")
        return files

    mean_file = os.path.join(all_specs_path, 'mean.npy')
    try:
        mean = np.load(mean_file)
    except FileNotFoundError:
        all_specs_list = list_spec_files()
        mean = 0.0
        for specobj in tqdm(all_specs_list, desc=f'Calculating mean for {dset_name} {proc_name}'):
            spec = pickleload(specobj).spec
            mean += np.mean(spec).item()
        mean = mean / len(all_specs_list)
        np.save(mean_file, mean)

    std_file = os.path.join(all_specs_path, 'std.npy')
    try:
        std = np.load(std_file)
    except FileNotFoundError:
        all_specs_list = list_spec_files()
        sum_of_mean_of_squared_dev = 0.0
        for specobj in tqdm(all_specs_list, desc=f'Calculating std for {dset_name} {proc_name}'):
            spec = pickleload(specobj).spec
            sum_of_mean_of_squared_dev += np.mean(np.square(spec - mean)).item()
        std = np.sqrt(sum_of_mean_of_squared_dev / len(all_specs_list))
        np.save(std_file, std)

    dset_stats[dset_name] = {'mean': mean, 'std': std}
    return mean, std
def normalize_spec(spec, mean=None, std=None, dset_name=None, aud_processor=None):
    """Standardise ``spec`` as ``(spec - mean) / std``.

    Args:
        spec: the spectrogram (anything supporting ``-`` and ``/`` with the
            statistics).
        mean, std: statistics to use. When both are ``None`` they are obtained
            from ``get_dataset_stats(dset_name, aud_processor)``.
        dset_name, aud_processor: forwarded to ``get_dataset_stats``; only
            used when neither statistic is given.

    Returns:
        The standardised spectrogram.

    Raises:
        AssertionError: if ``mean`` or ``std`` is neither a numpy array nor a
            float (this includes the case where only one of them is ``None``).
    """
    if mean is None and std is None:
        mean, std = get_dataset_stats(dset_name, aud_processor)
    # Each statistic is validated on its own, so an array loaded from disk may
    # be combined with a freshly computed scalar.
    valid_types = (np.ndarray, float, np.floating)
    assert isinstance(mean, valid_types) and isinstance(std, valid_types), \
        f"Either mean or std is not a float: mean={mean}, std={std}"
    return (spec - mean) / std
def slice_func(spec, length, processor=None, mode='random', offset_seconds=0, slice_times=None):
    """Cut a window of ``length`` frames out of ``spec`` along axis 1.

    Args:
        spec: 2-D array, indexed as ``spec[:, frames]``.
        length: window length in frames (cast to ``int``).
        processor: object providing ``times_to_frames`` and
            ``frames_to_times``; required despite the ``None`` default.
        mode: 'start', 'end', 'middle' or 'random' - where the window is taken.
        offset_seconds: offset converted to frames with the processor. It
            shifts the 'start' window and the 'middle' midpoint and is the
            lower bound of the 'random' start; 'end' ignores it.
        slice_times: optional ``(start_time, end_time)``; when given, exactly
            that span is returned and every other option is ignored.

    Returns:
        Tuple ``(output, start_time, end_time)``.

    Raises:
        Exception: for an unknown ``mode`` or when an empty ``spec`` would have
            to be padded.
    """
    if slice_times is not None:
        start_time, end_time = slice_times[0], slice_times[1]
        # Cast to int as is done for the offset below, so that the values are
        # valid slice indices.
        first = int(processor.times_to_frames(start_time))
        last = int(processor.times_to_frames(end_time))
        return spec[:, first:last], start_time, end_time

    offset_frames = int(processor.times_to_frames(offset_seconds))
    length = int(length)

    # Pad short inputs by repeating the spectrogram from its beginning until
    # offset + length frames exist. The amount appended is the number of frames
    # still missing, so each pass makes progress (deriving it from ``length``
    # alone stalls when length equals the width and an offset is set).
    required = offset_frames + length
    if spec.shape[-1] == 0 and required > 0:
        raise Exception("cannot slice an empty spectrogram")
    while spec.shape[-1] < required:
        missing = required - spec.shape[-1]
        spec = np.append(spec, spec[:, :missing], axis=1)

    xlen = spec.shape[-1]
    if mode == 'start':
        first = offset_frames
    elif mode == 'end':
        first = xlen - length
    elif mode == 'middle':
        midpoint = xlen // 2 + offset_frames
        # Window of exactly ``length`` frames around the midpoint; the reported
        # times describe this window.
        # NOTE(review): with a large offset the window can extend past the end
        # of ``spec`` and come back shorter than ``length``.
        first = midpoint - length // 2
    elif mode == 'random':
        first = torch.randint(offset_frames, xlen - length + 1, (1,))[0].item()
    else:
        raise Exception(f"mode must be in ['start', 'end', 'middle', 'random'], is {mode}")

    if mode == 'end':
        # Kept as a negative slice, matching "the last ``length`` frames".
        output = spec[:, -length:]
    else:
        output = spec[:, first: first + length]
    start_time = processor.frames_to_times(first)
    end_time = processor.frames_to_times(first + length)
    return output, start_time, end_time
class DsetNoLabel(torch.utils.data.Dataset):
    """Wrap a dataset and drop the last element of each item.

    Make sure that your dataset actually returns many elements: items with two
    or fewer elements are passed through unchanged.
    """

    def __init__(self, dset):
        self.dset = dset

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, index):
        item = self.dset[index]
        if len(item) <= 2:
            return item
        return item[:-1]
class DsetMultiDataSources(torch.utils.data.Dataset):
    """Serve one item from each of several datasets at the same index.

    ``__getitem__`` returns a tuple holding one ``(path, x, y)`` triplet per
    wrapped dataset. The index wraps around within each dataset
    (``index % len(ds)``) and the overall length is that of the shortest one.
    """

    def __init__(self, *dsets):
        self.dsets = dsets
        self.lengths = [len(d) for d in self.dsets]

    def __getitem__(self, index):
        return_triplets = []
        for ds in self.dsets:
            # Fetch exactly once: errors raised by the dataset itself propagate
            # and the sample is not loaded a second time.
            item = ds[index % len(ds)]
            try:
                path, x, y = item
            except (ValueError, TypeError):
                # if dataset does not return path, generate a (semi-)unique hash from the sum of the tensor,
                # to be used as an identifier of the tensor for caching
                x, y = item
                path = hashlib.md5(f"{str(torch.sum(x).item())}".encode("UTF-8")).hexdigest()
            return_triplets.append((path, x, y))
        return tuple(return_triplets)

    def __len__(self):
        return min(self.lengths)
class DsetThreeChannels(torch.utils.data.Dataset):
    """Wrap a dataset of ``(image, label)`` pairs and tile each image 3 times.

    The image is repeated three times along its first dimension
    (``image.repeat(3, 1, 1)``); the label is passed through unchanged.
    """

    def __init__(self, dset):
        self.dset = dset

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, index):
        sample = self.dset[index]
        image, label = sample
        tiled = image.repeat(3, 1, 1)
        return tiled, label