Commit 3b06d542 authored by Paul Primus's avatar Paul Primus
Browse files

add conv-autoencoder architecture

parent c64fb7de
......@@ -55,7 +55,8 @@ class MCMDataSet(BaseDataSet):
power=1.0,
fmin=0,
normalize_raw=False,
normalize=None
normalize=None,
hop_all=False
):
self.data_root = data_root
self.context = context
......@@ -65,6 +66,7 @@ class MCMDataSet(BaseDataSet):
self.power = power
self.fmin = fmin
self.normalize = normalize
self.hop_all = hop_all
assert type(machine_type) == int and type(machine_id) == int
......@@ -76,7 +78,8 @@ class MCMDataSet(BaseDataSet):
'hop_size': self.hop_size,
'power': power,
'normalize': normalize_raw,
'fmin': fmin
'fmin': fmin,
'hop_all': hop_all
}
training_set = MachineDataSet(machine_type, machine_id, mode='training', **kwargs)
......@@ -127,7 +130,8 @@ class MachineDataSet(torch.utils.data.Dataset):
hop_size=512,
power=2.0,
normalize=True,
fmin=0
fmin=0,
hop_all=False
):
assert mode in ['training', 'validation', 'testing']
......@@ -143,6 +147,7 @@ class MachineDataSet(torch.utils.data.Dataset):
self.machine_type = INVERSE_CLASS_MAP[machine_type]
self.machine_id = machine_id
self.fmin = fmin
self.hop_all = hop_all
if mode == 'training':
files = glob.glob(
......@@ -165,7 +170,7 @@ class MachineDataSet(torch.utils.data.Dataset):
files = sorted(files)
self.files = files
self.file_length = self.__load_preprocess_file__(files[0]).shape[-1]
self.num_samples_per_file = self.file_length - self.context + 1
self.num_samples_per_file = (self.file_length // self.context) if hop_all else (self.file_length - self.context + 1)
self.meta_data = self.__load_meta_data__(files)
self.data = self.__load_data__(files)
......@@ -175,7 +180,7 @@ class MachineDataSet(torch.utils.data.Dataset):
# get audio file index
item = item // self.num_samples_per_file
# load audio file and extract audio junk
offset = item * self.file_length + offset
offset = item * self.file_length + ((offset * self.context) if self.hop_all else offset)
observation = self.data[:, offset:offset + self.context]
# create data object
meta_data = self.meta_data[item].copy()
......
......@@ -156,7 +156,8 @@ def configuration():
'hop_size': hop_size,
'normalize_raw': normalize_raw,
'power': power,
'fmin': fmin
'fmin': fmin,
'hop_all': True
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment