Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Shreyan Chowdhury
midlevel_da
Commits
9f0e4820
Commit
9f0e4820
authored
Oct 23, 2020
by
Shreyan Chowdhury
Browse files
add project files
parent
e3345094
Changes
22
Hide whitespace changes
Inline
Side-by-side
.gitignore
0 → 100644
View file @
9f0e4820
*.pyc
*.idea*
conda_env_cpu.yml
0 → 100644
View file @
9f0e4820
name
:
midlevel_da
channels
:
-
conda-forge
-
defaults
dependencies
:
-
_libgcc_mutex=0.1=conda_forge
-
_openmp_mutex=4.5=1_llvm
-
appdirs=1.4.3=py_1
-
argon2-cffi=20.1.0=py38h1e0a361_1
-
async_generator=1.10=py_0
-
attrs=20.2.0=pyh9f0ad1d_0
-
audioread=2.1.8=py38h32f6830_2
-
backcall=0.2.0=pyh9f0ad1d_0
-
backports=1.0=py_2
-
backports.functools_lru_cache=1.6.1=py_0
-
blas=1.0=mkl
-
bleach=3.2.1=pyh9f0ad1d_0
-
brotlipy=0.7.0=py38h1e0a361_1000
-
bzip2=1.0.8=h516909a_2
-
ca-certificates=2020.6.20=hecda079_0
-
certifi=2020.6.20=py38h32f6830_0
-
cffi=1.14.1=py38h5bae8af_0
-
chardet=3.0.4=py38h32f6830_1006
-
cryptography=3.0=py38h766eaa4_0
-
cudatoolkit=10.2.89=hfd86e86_1
-
cycler=0.10.0=py_2
-
cython=0.29.21=py38h950e882_0
-
dbus=1.13.6=he372182_0
-
decorator=4.4.2=py_0
-
defusedxml=0.6.0=py_0
-
entrypoints=0.3=py38h32f6830_1001
-
expat=2.2.9=he1b5a44_2
-
ffmpeg=4.3.1=h167e202_0
-
fontconfig=2.13.1=h1056068_1002
-
freetype=2.10.2=he06d7ca_0
-
gettext=0.19.8.1=hc5be6a0_1002
-
glib=2.58.3=py38h73cb85d_1004
-
gmp=6.2.0=he1b5a44_2
-
gnutls=3.6.13=h79a8f9a_0
-
gst-plugins-base=1.14.5=h0935bb2_2
-
gstreamer=1.14.5=h36ae1b5_2
-
icu=67.1=he1b5a44_0
-
importlib-metadata=2.0.0=py38h32f6830_0
-
importlib_metadata=2.0.0=0
-
intel-openmp=2020.1=217
-
ipykernel=5.3.4=py38h23f93f0_0
-
ipython=7.17.0=py38h1cdfbd6_0
-
ipython_genutils=0.2.0=py_1
-
ipywidgets=7.5.1=pyh9f0ad1d_1
-
jedi=0.17.2=py38h32f6830_0
-
jinja2=2.11.2=pyh9f0ad1d_0
-
joblib=0.16.0=py_0
-
jpeg=9d=h516909a_0
-
jsonschema=3.2.0=py38h32f6830_1
-
jupyter=1.0.0=py_2
-
jupyter_client=6.1.6=py_0
-
jupyter_console=6.2.0=py_0
-
jupyter_core=4.6.3=py38h32f6830_1
-
jupyterlab_pygments=0.1.1=pyh9f0ad1d_0
-
kiwisolver=1.2.0=py38hbf85e49_0
-
krb5=1.17.1=hfafb76e_3
-
lame=3.100=h14c3975_1001
-
lcms2=2.11=hbd6801e_0
-
ld_impl_linux-64=2.34=hc38a660_9
-
libblas=3.8.0=16_mkl
-
libcblas=3.8.0=16_mkl
-
libclang=10.0.1=default_hde54327_1
-
libedit=3.1.20191231=h46ee950_1
-
libevent=2.1.10=hcdb4288_2
-
libffi=3.2.1=he1b5a44_1007
-
libflac=1.3.3=he1b5a44_0
-
libgcc-ng=9.3.0=h24d8f2e_14
-
libgfortran-ng=7.5.0=hdf63c60_14
-
libiconv=1.15=h516909a_1006
-
liblapack=3.8.0=16_mkl
-
libllvm10=10.0.1=he513fc3_3
-
libllvm9=9.0.1=he513fc3_1
-
libogg=1.3.2=h516909a_1002
-
libpng=1.6.37=hed695b0_1
-
libpq=12.3=h5513abc_0
-
librosa=0.8.0=pyh9f0ad1d_0
-
libsndfile=1.0.28=he1b5a44_1000
-
libsodium=1.0.18=h516909a_0
-
libstdcxx-ng=9.3.0=hdf63c60_14
-
libtiff=4.1.0=hc7e4089_6
-
libuuid=2.32.1=h14c3975_1000
-
libvorbis=1.3.7=he1b5a44_0
-
libwebp-base=1.1.0=h516909a_3
-
libxcb=1.13=h14c3975_1002
-
libxkbcommon=0.10.0=he1b5a44_0
-
libxml2=2.9.10=h72b56ed_2
-
llvm-openmp=10.0.1=hc9558a2_0
-
llvmlite=0.33.0=py38h4f45e52_1
-
lz4-c=1.9.2=he1b5a44_1
-
markupsafe=1.1.1=py38h1e0a361_1
-
matplotlib-base=3.3.0=py38h91b0d89_1
-
mistune=0.8.4=py38h1e0a361_1001
-
mkl=2020.2=256
-
mkl-service=2.3.0=py38he904b0f_0
-
mkl_fft=1.1.0=py38h23d657b_0
-
mkl_random=1.1.1=py38hcb8c335_0
-
mysql-common=8.0.21=2
-
mysql-libs=8.0.21=hf3661c5_2
-
nbclient=0.5.0=py_0
-
nbconvert=6.0.6=py38h32f6830_0
-
nbformat=5.0.7=py_0
-
ncurses=6.2=he1b5a44_1
-
nest-asyncio=1.4.0=py_1
-
nettle=3.4.1=h1bed415_1002
-
ninja=1.10.0=hc9558a2_0
-
notebook=6.1.4=py38h32f6830_0
-
nspr=4.29=he1b5a44_0
-
nss=3.57=he751ad9_0
-
numba=0.50.1=py38hcb8c335_1
-
numpy=1.19.1=py38h8854b6b_0
-
numpy-base=1.18.5=py38hde5b4d6_0
-
olefile=0.46=py_0
-
openh264=2.1.1=h8b12597_0
-
openssl=1.1.1h=h516909a_0
-
packaging=20.4=pyh9f0ad1d_0
-
pandas=1.1.0=py38h950e882_0
-
pandoc=2.10.1=h516909a_0
-
pandocfilters=1.4.2=py_1
-
parso=0.7.1=pyh9f0ad1d_0
-
patsy=0.5.1=py_0
-
pcre=8.44=he1b5a44_0
-
pexpect=4.8.0=py38h32f6830_1
-
pickleshare=0.7.5=py38h32f6830_1001
-
pillow=7.2.0=py38h9776b28_1
-
pip=20.2.1=py_0
-
pooch=1.1.1=py_0
-
prometheus_client=0.8.0=pyh9f0ad1d_0
-
prompt-toolkit=3.0.7=py_0
-
prompt_toolkit=3.0.7=0
-
pthread-stubs=0.4=h14c3975_1001
-
ptyprocess=0.6.0=py_1001
-
pycparser=2.20=pyh9f0ad1d_2
-
pygments=2.6.1=py_0
-
pyopenssl=19.1.0=py_1
-
pyparsing=2.4.7=pyh9f0ad1d_0
-
pyqt=5.12.3=py38ha8c2ead_3
-
pyrsistent=0.17.3=py38h1e0a361_0
-
pysocks=1.7.1=py38h32f6830_1
-
pysoundfile=0.10.2=py_1001
-
python=3.8.5=h4d41432_2_cpython
-
python-dateutil=2.8.1=py_0
-
python_abi=3.8=1_cp38
-
pytz=2020.1=pyh9f0ad1d_0
-
pyzmq=19.0.2=py38ha71036d_0
-
qt=5.12.9=h1f2b2cb_0
-
qtconsole=4.7.7=pyh9f0ad1d_0
-
qtpy=1.9.0=py_0
-
readline=8.0=he28a2e2_2
-
resampy=0.2.2=py_0
-
scikit-learn=0.23.2=py38hee58b96_0
-
scipy=1.5.2=py38h8c5af15_0
-
seaborn=0.10.1=1
-
seaborn-base=0.10.1=py_1
-
send2trash=1.5.0=py_0
-
setuptools=49.3.1=py38h32f6830_0
-
six=1.15.0=pyh9f0ad1d_0
-
sqlite=3.33.0=h4cf870e_0
-
statsmodels=0.11.1=py38h1e0a361_2
-
tbb=2020.1=hc9558a2_0
-
terminado=0.9.1=py38h32f6830_0
-
testpath=0.4.4=py_0
-
threadpoolctl=2.1.0=pyh5ca1d4c_0
-
tk=8.6.10=hed695b0_0
-
tornado=6.0.4=py38h1e0a361_1
-
tqdm=4.48.2=pyh9f0ad1d_0
-
traitlets=4.3.3=py38h32f6830_1
-
wcwidth=0.2.5=pyh9f0ad1d_1
-
webencodings=0.5.1=py_1
-
wheel=0.34.2=py_1
-
widgetsnbextension=3.5.1=py38h32f6830_1
-
x264=1!152.20180806=h14c3975_0
-
xorg-libxau=1.0.9=h14c3975_0
-
xorg-libxdmcp=1.1.3=h516909a_0
-
xz=5.2.5=h516909a_1
-
zeromq=4.3.2=he1b5a44_3
-
zipp=3.2.0=py_0
-
zlib=1.2.11=h516909a_1007
-
zstd=1.4.5=h6597ccf_2
-
pip
:
-
absl-py==0.9.0
-
cachetools==4.1.0
-
future==0.18.2
-
google-auth==1.16.0
-
google-auth-oauthlib==0.4.1
-
grpcio==1.29.0
-
idna==2.9
-
madmom==0.16.1
-
markdown==3.2.2
-
mido==1.2.9
-
oauthlib==3.1.0
-
protobuf==3.12.2
-
pyasn1==0.4.8
-
pyasn1-modules==0.2.8
-
pyqt5-sip==4.19.18
-
pyqtchart==5.12
-
pyqtwebengine==5.12.1
-
requests==2.23.0
-
requests-oauthlib==1.3.0
-
rsa==4.0
-
tensorboard==2.2.2
-
tensorboard-plugin-wit==1.6.0.post3
-
torch==1.6.0
-
torchaudio==0.5.1
-
torchcontrib==0.0.2
-
torchsummary==1.5.1
-
torchvision==0.7.0
-
urllib3==1.25.9
-
werkzeug==1.0.1
conda_env_gpu.yml
0 → 100644
View file @
9f0e4820
name
:
midlevel_da
channels
:
-
pytorch
-
conda-forge
-
defaults
dependencies
:
-
_libgcc_mutex=0.1=main
-
audioread=2.1.8=py38h32f6830_2
-
blas=1.0=mkl
-
bzip2=1.0.8=h516909a_2
-
ca-certificates=2020.1.1=0
-
certifi=2020.4.5.1=py38_0
-
cffi=1.14.0=py38he30daa8_1
-
cudatoolkit=10.2.89=hfd86e86_1
-
cycler=0.10.0=py_2
-
cython=0.29.17=py38he6710b0_0
-
dbus=1.13.14=hb2f20db_0
-
decorator=4.4.2=py_0
-
expat=2.2.6=he6710b0_0
-
ffmpeg=4.2.3=h167e202_0
-
fontconfig=2.13.0=h9420a91_0
-
freetype=2.9.1=h8a8886c_1
-
gettext=0.19.8.1=h5e8e0c9_1
-
glib=2.63.1=h3eb4bd4_1
-
gmp=6.2.0=he1b5a44_2
-
gnutls=3.6.13=h79a8f9a_0
-
gst-plugins-base=1.14.0=hbbd80ab_1
-
gstreamer=1.14.0=hb31296c_0
-
icu=58.2=hf484d3e_1000
-
intel-openmp=2020.1=217
-
joblib=0.15.1=py_0
-
jpeg=9b=h024ee3a_2
-
kiwisolver=1.2.0=py38hbf85e49_0
-
lame=3.100=h14c3975_1001
-
ld_impl_linux-64=2.33.1=h53a641e_7
-
libedit=3.1.20181209=hc058e9b_0
-
libffi=3.3=he6710b0_1
-
libflac=1.3.3=he1b5a44_0
-
libgcc-ng=9.1.0=hdf63c60_0
-
libgfortran-ng=7.3.0=hdf63c60_0
-
libiconv=1.15=h516909a_1006
-
libogg=1.3.2=h516909a_1002
-
libpng=1.6.37=hbc83047_0
-
librosa=0.7.2=py_1
-
libsndfile=1.0.28=he1b5a44_1000
-
libstdcxx-ng=9.1.0=hdf63c60_0
-
libtiff=4.1.0=h2733197_0
-
libuuid=1.0.3=h1bed415_2
-
libvorbis=1.3.6=he1b5a44_2
-
libxcb=1.13=h1bed415_1
-
libxml2=2.9.9=hea5a465_1
-
llvmlite=0.32.1=py38hd408876_0
-
matplotlib=3.1.3=py38_0
-
matplotlib-base=3.1.3=py38hef1b27d_0
-
mkl=2020.1=217
-
mkl-service=2.3.0=py38he904b0f_0
-
mkl_fft=1.0.15=py38ha843d7b_0
-
mkl_random=1.1.1=py38h0573a6f_0
-
ncurses=6.2=he6710b0_1
-
nettle=3.4.1=h1bed415_1002
-
ninja=1.9.0=py38hfd86e86_0
-
numba=0.49.1=py38h0573a6f_0
-
numpy=1.18.1=py38h4f9e942_0
-
numpy-base=1.18.1=py38hde5b4d6_1
-
olefile=0.46=py_0
-
openh264=2.1.1=h8b12597_0
-
openssl=1.1.1g=h7b6447c_0
-
pandas=1.0.3=py38h0573a6f_0
-
pcre=8.43=he6710b0_0
-
pillow=7.1.2=py38hb39fc2d_0
-
pip=20.0.2=py38_3
-
pycparser=2.20=py_0
-
pyparsing=2.4.7=pyh9f0ad1d_0
-
pyqt=5.9.2=py38h05f1152_4
-
pysoundfile=0.10.2=py_1001
-
python=3.8.3=hcff3b4d_0
-
python-dateutil=2.8.1=py_0
-
python_abi=3.8=1_cp38
-
pytz=2020.1=py_0
-
qt=5.9.7=h5867ecd_1
-
readline=8.0=h7b6447c_0
-
resampy=0.2.2=py_0
-
scikit-learn=0.22.1=py38hd81dba3_0
-
scipy=1.4.1=py38h0b6359f_0
-
seaborn=0.10.1=py_0
-
setuptools=46.4.0=py38_0
-
sip=4.19.13=py38he6710b0_0
-
six=1.14.0=py38_0
-
sqlite=3.31.1=h62c20be_1
-
tbb=2020.1=hc9558a2_0
-
tk=8.6.8=hbc83047_0
-
torchvision=0.6.0=py38_cu102
-
tornado=6.0.4=py38h1e0a361_1
-
tqdm=4.46.0=py_0
-
wheel=0.34.2=py38_0
-
x264=1!152.20180806=h14c3975_0
-
xz=5.2.5=h7b6447c_0
-
zlib=1.2.11=h7b6447c_3
-
zstd=1.3.7=h0b5b093_0
-
pip
:
-
absl-py==0.9.0
-
blessings==1.7
-
cachetools==4.1.0
-
chardet==3.0.4
-
future==0.18.2
-
google-auth==1.16.0
-
google-auth-oauthlib==0.4.1
-
gpustat==0.6.0
-
grpcio==1.29.0
-
idna==2.9
-
madmom==0.16.1
-
markdown==3.2.2
-
mido==1.2.9
-
nvidia-ml-py3==7.352.0
-
oauthlib==3.1.0
-
protobuf==3.12.2
-
psutil==5.7.0
-
pyasn1==0.4.8
-
pyasn1-modules==0.2.8
-
requests==2.23.0
-
requests-oauthlib==1.3.0
-
rsa==4.0
-
tensorboard==2.2.2
-
tensorboard-plugin-wit==1.6.0.post3
-
torch==1.5.1
-
torchaudio==0.5.1
-
torchcontrib==0.0.2
-
urllib3==1.25.9
-
werkzeug==1.0.1
midlevel_da/datasets/__init__.py
0 → 100644
View file @
9f0e4820
midlevel_da/datasets/dataset_augment.py
0 → 100644
View file @
9f0e4820
import
numpy
as
np
import
torch
import
torch.utils.data
def flip_ver(x):
    """Return a copy of tensor ``x`` with the order of dimension 1 reversed.

    NOTE(review): the name suggests a vertical flip, i.e. dim 1 is presumably
    the height/frequency axis of a (C, H, W) tensor -- confirm with callers.
    """
    return torch.flip(x, dims=(1,))
def flip_hor(x):
    """Return a copy of tensor ``x`` with the order of dimension 2 reversed.

    NOTE(review): the name suggests a horizontal flip, i.e. dim 2 is presumably
    the width/time axis of a (C, H, W) tensor -- confirm with callers.
    """
    return torch.flip(x, dims=(2,))
class DsetSSFlipRand(torch.utils.data.Dataset):
    """Self-supervised wrapper that randomly flips samples along dimension 1.

    Each item of the wrapped dataset must be sliceable so that its first two
    entries are ``(path, x)``. ``__getitem__`` returns ``(path, x, label)``
    where ``label`` is 1 if ``x`` was flipped and 0 if it was left untouched.
    """

    def __init__(self, dset):
        self.dset = dset

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, index):
        path, x = self.dset[index][:2]
        # Coin toss decides both the transformation and the pretext label.
        label = np.random.randint(2)
        flipped = label != 0
        return path, (x.flip(1) if flipped else x), label
class DsetSSMixup(torch.utils.data.Dataset):
    """Self-supervised wrapper that randomly splices two samples together.

    Each item of the wrapped dataset must be sliceable so that its first two
    entries are ``(path, x)``. ``__getitem__`` returns ``(path, x, label)``:

    * ``label == 0``: ``x`` is returned unchanged.
    * ``label == 1``: the part of ``x`` after the midpoint of dimension 1 is
      replaced by the corresponding part of another, randomly chosen sample.

    NOTE(review): indexing is ``x[:, :split]``, so ``x`` needs at least two
    dimensions and the splice happens along dim 1 -- confirm this is the time
    axis of the tensors produced by the wrapped dataset.
    """

    def __init__(self, dset):
        self.dset = dset

    def __getitem__(self, index):
        path, x = self.dset[index][:2]

        n_items = len(self.dset)
        rand_idx = np.random.randint(n_items)
        if rand_idx == index:
            # Step to the neighbouring item, wrapping around at the end.
            # A plain ``+= 1`` ran past the dataset when ``index`` was the last
            # item. With a single-item dataset the partner is the item itself.
            rand_idx = (rand_idx + 1) % n_items

        label = np.random.randint(2)
        if label == 1:
            # Only load the partner when it is actually mixed in.
            x1 = self.dset[rand_idx][1]
            split_pt = x.shape[1] // 2
            x = torch.cat([x[:, :split_pt], x1[:, split_pt:]], 1)
        return path, x, label

    def __len__(self):
        return len(self.dset)
class DsetSSTimeSwap(torch.utils.data.Dataset):
    """Self-supervised wrapper that randomly swaps the two halves of a sample.

    Each item of the wrapped dataset must be sliceable so that its first two
    entries are ``(path, x)``. ``__getitem__`` returns ``(path, x, label)``;
    with ``label == 1`` the part of ``x`` from the midpoint of dimension 1
    onwards is moved in front of the first part, with ``label == 0`` ``x`` is
    returned as is.
    """

    def __init__(self, dset):
        self.dset = dset

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, index):
        path, x = self.dset[index][:2]
        label = np.random.randint(2)
        if label == 0:
            return path, x, label

        half = x.shape[1] // 2
        head, tail = x[:, :half], x[:, half:]
        return path, torch.cat([tail, head], 1), label
def tensor_rot_90(x):
    """Reverse dimension 2 of ``x`` and then swap dimensions 1 and 2.

    For a tensor with at least three dimensions this rotates every
    (dim 1, dim 2) plane by a quarter turn.
    """
    mirrored = torch.flip(x, dims=(2,))
    return torch.transpose(mirrored, 1, 2)
def tensor_rot_90_digit(x):
    """Swap dimensions 1 and 2 of ``x`` (a plain transpose, no flip).

    NOTE(review): used as the "90 degree" variant for digit data; why digits
    skip the flip is not visible here -- confirm with the author.
    """
    return torch.transpose(x, 1, 2)
def tensor_rot_180(x):
    """Reverse both dimension 1 and dimension 2 of ``x`` (half turn of each plane)."""
    return torch.flip(x, dims=(1, 2))
def tensor_rot_180_digit(x):
    """Reverse only dimension 2 of ``x``.

    NOTE(review): used as the "180 degree" variant for digit data even though
    it is a mirror along dim 2 rather than a half turn -- confirm intent.
    """
    return torch.flip(x, dims=(2,))
def tensor_rot_270(x):
    """Swap dimensions 1 and 2 of ``x`` and then reverse dimension 2.

    For a tensor with at least three dimensions this rotates every
    (dim 1, dim 2) plane by a quarter turn, opposite to ``x.flip(2).transpose(1, 2)``.
    """
    swapped = torch.transpose(x, 1, 2)
    return torch.flip(swapped, dims=(2,))
class DsetSSRotRand(torch.utils.data.Dataset):
    """Self-supervised wrapper that applies one of four random "rotations".

    Only the first entry of each wrapped item is used. ``__getitem__`` returns
    ``(image, label)`` with ``label`` drawn uniformly from ``{0, 1, 2, 3}``:

    * 0 -- image unchanged
    * 1 -- ``tensor_rot_90_digit`` if ``digit`` else ``tensor_rot_90``
    * 2 -- ``tensor_rot_180_digit`` if ``digit`` else ``tensor_rot_180``
    * 3 -- ``tensor_rot_270`` (the ``digit`` flag is not consulted)
    """

    def __init__(self, dset, digit=True):
        self.dset = dset
        self.digit = digit

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, index):
        image = self.dset[index][0]
        label = np.random.randint(4)

        if label == 0:
            return image, label
        if label == 3:
            return tensor_rot_270(image), label

        # Labels 1 and 2 pick between the regular and the "digit" variant.
        if label == 1:
            rotate = tensor_rot_90_digit if self.digit else tensor_rot_90
        else:
            rotate = tensor_rot_180_digit if self.digit else tensor_rot_180
        return rotate(image), label
midlevel_da/datasets/dataset_utils.py
0 → 100644
View file @
9f0e4820
import
hashlib
import
os
from
tqdm
import
tqdm
from
helpers.audio
import
MadmomAudioProcessor
from
paths
import
*
from
utils
import
*
import
torch
import
torch.utils.data
import
numpy
as
np
from
sklearn.model_selection
import
train_test_split
# Process-wide cache of statistics that were already resolved:
# {dset_name: {'mean': ..., 'std': ...}}.
dset_stats = {}


def _list_specobj_files(all_specs_path):
    """Return all ``.specobj`` files below ``all_specs_path``; fail loudly if there are none."""
    spec_files = list_files_deep(all_specs_path, full_paths=True, filter_ext=['.specobj'])
    if len(spec_files) == 0:
        # Without this guard the averaging in the caller dies with a bare ZeroDivisionError.
        raise Exception(f"no .specobj files found in {all_specs_path}")
    return spec_files


def get_dataset_stats(dset_name, aud_processor=None):
    """Return ``(mean, std)`` of the cached spectrograms of dataset ``dset_name``.

    Lookup order:

    1. the in-process ``dset_stats`` cache,
    2. ``mean.npy`` / ``std.npy`` inside
       ``<path_cache_fs>/<dset_name>/<processor name>``,
    3. a fresh computation over every ``.specobj`` file in that directory,
       whose result is written back to the ``.npy`` files.

    The computed mean is the average of the per-file means and the std is the
    square root of the average per-file mean squared deviation, so files are
    weighted equally regardless of their length.

    Args:
        dset_name: name of the dataset (sub-directory of ``path_cache_fs``).
        aud_processor: object whose ``get_params.get("name")`` names the
            spectrogram cache directory. Defaults to
            ``MadmomAudioProcessor(fps=31.3)``.

    Returns:
        Tuple ``(mean, std)``. Freshly computed values are scalars; values read
        back with ``np.load`` are NumPy arrays.

    Raises:
        Exception: if a cache entry lacks ``'mean'`` or ``'std'``, or if
            statistics have to be computed but no ``.specobj`` file exists.

    NOTE(review): the in-process cache is keyed by ``dset_name`` only, so a
    second call with a different ``aud_processor`` returns the statistics of
    the first one -- confirm this is acceptable.
    """
    global dset_stats

    # Serve cache hits before touching the processor or the file system.
    cached = dset_stats.get(dset_name)
    if cached is not None:
        try:
            return cached['mean'], cached['std']
        except KeyError as err:
            raise Exception(f"mean or std not found for {dset_name}") from err

    if aud_processor is None:
        aud_processor = MadmomAudioProcessor(fps=31.3)
    processor_name = aud_processor.get_params.get("name")

    dataset_cache_dir = os.path.join(path_cache_fs, dset_name)
    all_specs_path = os.path.join(dataset_cache_dir, processor_name)
    mean_file = os.path.join(all_specs_path, 'mean.npy')
    std_file = os.path.join(all_specs_path, 'std.npy')

    try:
        mean = np.load(mean_file)
    except FileNotFoundError:
        all_specs_list = _list_specobj_files(all_specs_path)
        mean = 0.0
        for specobj in tqdm(all_specs_list, desc=f'Calculating mean for {dset_name} {processor_name}'):
            mean += np.mean(pickleload(specobj).spec).item()
        mean = mean / len(all_specs_list)
        np.save(mean_file, mean)

    try:
        std = np.load(std_file)
    except FileNotFoundError:
        all_specs_list = _list_specobj_files(all_specs_path)
        sum_of_mean_of_squared_dev = 0.0
        for specobj in tqdm(all_specs_list, desc=f'Calculating std for {dset_name} {processor_name}'):
            spec = pickleload(specobj).spec
            sum_of_mean_of_squared_dev += np.mean(np.square(spec - mean)).item()
        std = np.sqrt(sum_of_mean_of_squared_dev / len(all_specs_list))
        np.save(std_file, std)

    dset_stats[dset_name] = {'mean': mean, 'std': std}
    return mean, std
def normalize_spec(spec, mean=None, std=None, dset_name=None, aud_processor=None):
    """Standardize ``spec`` as ``(spec - mean) / std``.

    Args:
        spec: spectrogram (anything supporting ``- mean`` and ``/ std``).
        mean, std: statistics to use. When *both* are ``None`` they are looked
            up with ``get_dataset_stats(dset_name, aud_processor)``.
        dset_name, aud_processor: only used for that lookup.

    Returns:
        The standardized spectrogram.

    Raises:
        AssertionError: if ``mean`` or ``std`` is neither a NumPy array nor a
            float (Python ``float`` or NumPy floating scalar). This includes
            the case where only one of them was passed as ``None``.
    """
    if mean is None and std is None:
        mean, std = get_dataset_stats(dset_name, aud_processor)

    def _is_stat(value):
        return isinstance(value, (np.ndarray, float, np.floating))

    # Each statistic is validated on its own: an array/float mix is legitimate
    # (e.g. one value loaded from disk, the other freshly computed).
    # Raised explicitly so the check survives ``python -O`` and the error
    # actually carries the message.
    if not (_is_stat(mean) and _is_stat(std)):
        raise AssertionError(f"Either mean or std is not a float: mean={mean}, std={std}")

    return (spec - mean) / std
def slice_func(spec, length, processor=None, mode='random', offset_seconds=0, slice_times=None):
    """Cut a window of ``length`` frames out of ``spec``.

    Args:
        spec: array indexed as ``spec[:, frames]``; padding uses
            ``np.append(..., axis=1)``, so a 2-D NumPy array is assumed
            (TODO confirm against callers).
        length: window length in frames (converted with ``int``).
        processor: object providing ``times_to_frames`` and
            ``frames_to_times``. Required in practice although it defaults to
            ``None``.
        mode: ``'start'`` (window begins at the offset), ``'end'`` (last
            ``length`` frames), ``'middle'`` (window centred on
            ``len // 2 + offset``) or ``'random'`` (uniformly random start at
            or after the offset).
        offset_seconds: offset converted to frames via the processor.
        slice_times: optional ``(start_time, end_time)``. When given, all
            other options are ignored and exactly that time range is returned.

    Returns:
        Tuple ``(window, start_time, end_time)``.

    Raises:
        ValueError: if ``spec`` has no frames and therefore cannot be padded.
        Exception: if ``mode`` is not one of the four supported values.
    """
    if slice_times is not None:
        start_time, end_time = slice_times[0], slice_times[1]
        first = processor.times_to_frames(start_time)
        last = processor.times_to_frames(end_time)
        return spec[:, first:last], start_time, end_time

    offset_frames = int(processor.times_to_frames(offset_seconds))
    length = int(length)
    required = offset_frames + length

    # Too short: repeat the spectrogram from its beginning until offset + length
    # frames are available. The amount to append is measured against
    # ``required`` (not ``length``); otherwise a positive offset leads to an
    # empty append (endless loop) or a negative slice bound.
    if spec.shape[-1] == 0 and required > 0:
        raise ValueError("cannot slice an empty spectrogram")
    while spec.shape[-1] < required:
        spec = np.append(spec, spec[:, :required - spec.shape[-1]], axis=1)

    xlen = spec.shape[-1]

    if mode == 'start':
        first = offset_frames
    elif mode == 'end':
        first = xlen - length
    elif mode == 'middle':
        midpoint = xlen // 2 + offset_frames
        # Clamp so that a large offset cannot push the window past the end.
        first = min(midpoint - length // 2, xlen - length)
    elif mode == 'random':
        first = torch.randint(offset_frames, xlen - length + 1, (1,))[0].item()
    else:
        raise Exception(f"mode must be in ['start', 'end', 'middle', 'random'], is {mode}")

    # Every mode returns exactly ``length`` frames and timestamps that describe
    # the returned window.
    start_time = processor.frames_to_times(first)
    end_time = processor.frames_to_times(first + length)
    output = spec[:, first:first + length]
    return output, start_time, end_time
class DsetNoLabel(torch.utils.data.Dataset):
    """Wrapper that drops the last entry (the label) of every wrapped item.

    Only items with more than two entries are shortened; items with one or two
    entries are passed through unchanged, so make sure that your dataset
    actually returns many elements!
    """

    def __init__(self, dset):
        self.dset = dset

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, index):
        item = self.dset[index]
        if len(item) <= 2:
            return item
        return item[:-1]
class DsetMultiDataSources(torch.utils.data.Dataset):
    """Zip several datasets together, one ``(path, x, y)`` triplet per source.

    ``__getitem__(index)`` returns a tuple with one triplet for every wrapped
    dataset, taken at ``index % len(dataset)``. The combined length is the
    length of the shortest wrapped dataset.

    Wrapped items may be ``(path, x, y)`` or ``(x, y)``. For the latter a
    substitute path is generated (see ``__getitem__``).
    """

    def __init__(self, *dsets):
        self.dsets = dsets
        self.lengths = [len(d) for d in self.dsets]

    def __getitem__(self, index):
        return_triplets = []
        for ds in self.dsets:
            # Index the source exactly once: a second access would double the
            # loading cost and, for randomized datasets, yield another sample.
            item = ds[index % len(ds)]
            n_fields = len(item)
            if n_fields == 3:
                path, x, y = item
            elif n_fields == 2:
                # if dataset does not return path, generate a (semi-)unique hash
                # from the sum of the tensor, to be used as an identifier of the
                # tensor for caching
                x, y = item
                path = hashlib.md5(f"{str(torch.sum(x).item())}".encode("UTF-8")).hexdigest()
            else:
                raise ValueError(
                    f"expected dataset items with 2 or 3 entries, got {n_fields}"
                )
            return_triplets.append((path, x, y))
        return tuple(return_triplets)

    def __len__(self):
        return min(self.lengths)
class DsetThreeChannels(torch.utils.data.Dataset):
    """Wrapper that tiles each image three times along its first dimension.

    Wrapped items must be ``(image, label)`` pairs. The image is returned as
    ``image.repeat(3, 1, 1)``, which turns a ``(1, H, W)`` tensor into
    ``(3, H, W)``. The label is passed through.
    """

    def __init__(self, dset):
        self.dset = dset

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, index):
        image, label = self.dset[index]
        tiled = image.repeat(3, 1, 1)
        return tiled, label
if __name__ == '__main__':
    # Manual smoke check: resolve (and, if needed, compute and cache) the
    # statistics of the 'midlevel' dataset and show them.
    midlevel_stats = get_dataset_stats('midlevel')
    print(midlevel_stats)
midlevel_da/datasets/extra_datasets.py
0 → 100644
View file @
9f0e4820
from
sklearn.model_selection
import
train_test_split