Commit bd423e3e authored by Paul Primus's avatar Paul Primus
Browse files

add conv-autoencoder architecture

parent 1a46146c
......@@ -32,6 +32,15 @@ TRAINING_ID_MAP = {
5: [0, 2, 4, 6]
}
EVALUATION_ID_MAP = {
0: [1, 3, 5],
1: [1, 3, 5],
2: [1, 3, 5],
3: [5, 6, 7],
4: [4, 5, 6],
5: [1, 3, 5],
}
def enumerate_development_datasets():
typ_id = []
......@@ -41,6 +50,14 @@ def enumerate_development_datasets():
return typ_id
def enumerate_evaluation_datasets():
typ_id = []
for i in range(6):
for j in EVALUATION_ID_MAP[i]:
typ_id.append((i, j))
return typ_id
class MCMDataSet(BaseDataSet):
def __init__(
......@@ -134,7 +151,7 @@ class MachineDataSet(torch.utils.data.Dataset):
hop_all=False
):
assert mode in ['training', 'validation', 'testing']
assert mode in ['training', 'validation']
self.num_mel = num_mel
self.n_fft = n_fft
......@@ -149,22 +166,28 @@ class MachineDataSet(torch.utils.data.Dataset):
self.fmin = fmin
self.hop_all = hop_all
if machine_id in TRAINING_ID_MAP[machine_type]:
root_folder = 'dev_data'
elif machine_id in EVALUATION_ID_MAP[machine_type]:
root_folder = 'eval_data'
else:
raise AttributeError
if mode == 'training':
files = glob.glob(
os.path.join(
data_root, 'dev_data', self.machine_type, 'train', '*_id_{:02d}_*.wav'.format(machine_id)
data_root, root_folder, self.machine_type, 'train', '*id_{:02d}_*.wav'.format(machine_id)
)
)
elif mode == 'validation':
files = glob.glob(
os.path.join(
data_root, 'dev_data', self.machine_type, 'test', '*_id_{:02d}_*.wav'.format(machine_id)
data_root, root_folder, self.machine_type, 'test', '*id_{:02d}_*.wav'.format(machine_id)
)
)
elif mode == 'testing':
raise NotImplementedError
else:
raise AttributeError
assert len(files) > 0
files = sorted(files)
......@@ -221,6 +244,8 @@ class MachineDataSet(torch.utils.data.Dataset):
self.machine_id))
data = np.empty((self.num_mel, self.file_length * len(files)), dtype=np.float32)
for i, f in enumerate(files):
file = self.__load_preprocess_file__(f)
assert file.shape[1] == self.file_length
data[:, i * self.file_length:(i + 1) * self.file_length] = self.__load_preprocess_file__(f)
np.save(file_path, data)
......@@ -254,20 +279,36 @@ class MachineDataSet(torch.utils.data.Dataset):
meta_data = os.path.split(file_path)[-1].split('_')
machine_type = os.path.split(os.path.split(os.path.split(file_path)[0])[0])[1]
machine_type = CLASS_MAP[machine_type]
assert self.machine_type == INVERSE_CLASS_MAP[machine_type]
if len(meta_data) == 4:
y = 0 if meta_data[0] == 'normal' else 1
id = self.machine_id
part = int(meta_data[3].split('.')[0])
if meta_data[0] == 'normal':
y = 0
elif meta_data[0] == 'anomaly':
y = 1
else:
raise AttributeError
assert self.machine_id == int(meta_data[2])
elif len(meta_data) == 3:
y = -1
id = self.machine_id
part = int(meta_data[2].split('.')[0])
assert self.machine_id == int(meta_data[1])
else:
raise AttributeError
return {
'targets': y,
'machine_types': machine_type,
'machine_ids': id,
'machine_ids': self.machine_id,
'file_ids': os.sep.join(os.path.normpath(file_path).split(os.sep)[-4:])
}
if __name__ == '__main__':
for type_, id_ in enumerate_development_datasets():
_ = MachineDataSet(type_, id_, mode='training')
_ = MachineDataSet(type_, id_, mode='validation')
for type_, id_ in enumerate_evaluation_datasets():
_ = MachineDataSet(type_, id_, mode='training')
_ = MachineDataSet(type_, id_, mode='validation')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment