# complement_dataset.py
import os
import torch.utils.data
from dcase2020_task2.data_sets import BaseDataSet, CLASS_MAP, INVERSE_CLASS_MAP, TRAINING_ID_MAP, ALL_ID_MAP
from dcase2020_task2.data_sets import MachineDataSet
import numpy as np

# Lookup table: VALID_TYPES[policy][machine_type] -> list of machine type
# indices whose recordings may be used as "complement" data for that type.
#
# The integer type indices fall into two groups in the two narrower policies:
# {0, 1, 2, 5} and {3, 4}.
# NOTE(review): the meaning of each index presumably comes from CLASS_MAP in
# dcase2020_task2.data_sets -- confirm there before relying on the grouping.
VALID_TYPES = {

    # Only the *other* types of the same group; the target type itself is
    # never listed.
    'strict': {
        0: [1, 2, 5],
        1: [0, 2, 5],
        2: [0, 1, 5],
        5: [0, 1, 2],
        3: [4],
        4: [3],
    },
    # Same groups as 'strict', but the target type itself is listed as well
    # (first entry of each list).
    'loose': {
        0: [0, 1, 2, 5],
        1: [1, 0, 2, 5],
        2: [2, 0, 1, 5],
        5: [5, 0, 1, 2],
        3: [3, 4],
        4: [4, 3],
    },
    # No grouping: every type maps to all six types.
    'very_loose': {
        0: [0, 1, 2, 3, 4, 5],
        1: [0, 1, 2, 3, 4, 5],
        2: [0, 1, 2, 3, 4, 5],
        5: [0, 1, 2, 3, 4, 5],
        3: [0, 1, 2, 3, 4, 5],
        4: [0, 1, 2, 3, 4, 5],
    },

}


class ComplementMCMDataSet(BaseDataSet):
    """Training data built from every machine *except* one target machine.

    For the target ``(machine_type, machine_id)`` the constructor loads one
    ``MachineDataSet`` (``mode='training'``) for each ``(type, id)`` pair that
    is allowed by ``VALID_TYPES[valid_types][machine_type]`` and listed in
    ``ALL_ID_MAP``, skipping the target pair itself. The individual sets are
    exposed as a single ``torch.utils.data.ConcatDataset`` and the mean /
    standard deviation over their concatenated ``data`` arrays is stored.

    Args:
        machine_type: Integer index of the target machine type (``>= 0``).
        machine_id: Integer id of the target machine (``>= 0``).
        data_root: Directory passed to every ``MachineDataSet`` as
            ``data_root``.
        context, num_mel, n_fft, hop_size, power, fmin, hop_all,
        normalize_spec: Stored on the instance and forwarded unchanged to
            ``MachineDataSet`` under the same keyword names.
        normalize_raw: Stored on the instance and forwarded to
            ``MachineDataSet`` as its ``normalize`` keyword.
        valid_types: Key into ``VALID_TYPES`` (``'strict'``, ``'loose'`` or
            ``'very_loose'``) selecting which machine types contribute.

    Raises:
        AssertionError: If ``machine_type`` / ``machine_id`` are not plain
            ``int`` objects or are negative.
        KeyError: If ``valid_types`` or ``machine_type`` is not a key of
            ``VALID_TYPES``, or a listed type is missing from ``ALL_ID_MAP``.
        ValueError: If no complement data set could be collected.
    """

    def __init__(
            self,
            machine_type,
            machine_id,
            data_root=os.path.join(os.path.expanduser('~'), 'shared', 'dcase2020_task2'),
            context=5,
            num_mel=128,
            n_fft=1024,
            hop_size=512,
            power=1.0,
            fmin=0,
            normalize_raw=True,
            normalize_spec=False,
            hop_all=False,
            valid_types='strict'
    ):

        # Exact-type check (not isinstance): bools and numpy integers are
        # rejected, exactly as before.
        assert type(machine_type) is int and type(machine_id) is int, \
            'machine_type and machine_id must be of type int'
        assert machine_id >= 0, 'machine_id must be non-negative'
        assert machine_type >= 0, 'machine_type must be non-negative'

        self.data_root = data_root
        self.context = context
        self.num_mel = num_mel
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.power = power
        self.fmin = fmin
        self.hop_all = hop_all
        self.normalize_raw = normalize_raw
        self.normalize_spec = normalize_spec
        self.valid_types = valid_types

        # Settings shared by every per-machine data set. Note the rename:
        # this class' `normalize_raw` is MachineDataSet's `normalize`.
        kwargs = {
            'data_root': self.data_root,
            'context': self.context,
            'num_mel': self.num_mel,
            'n_fft': self.n_fft,
            'hop_size': self.hop_size,
            'power': self.power,
            'normalize': self.normalize_raw,
            'fmin': self.fmin,
            'hop_all': self.hop_all,
            'normalize_spec': self.normalize_spec
        }

        training_sets = []

        data = []
        for type_ in VALID_TYPES[self.valid_types][machine_type]:
            for id_ in ALL_ID_MAP[type_]:
                # Skip only the exact target machine; other ids of the same
                # type are kept when the policy lists the target type.
                if type_ != machine_type or id_ != machine_id:
                    t = MachineDataSet(type_, id_, mode='training', **kwargs)
                    data.append(t.data)
                    training_sets.append(t)

        if not training_sets:
            # np.concatenate would raise a ValueError here as well, but with
            # a message that does not point at the actual cause.
            raise ValueError(
                'no complement data sets found for machine_type={}, machine_id={}, '
                'valid_types={!r}'.format(machine_type, machine_id, valid_types)
            )

        # NOTE(review): presumably each `t.data` is (num_mel, frames), so this
        # joins along time and the statistics below are per mel bin --
        # confirm against MachineDataSet.
        data = np.concatenate(data, axis=-1)

        self.mean = data.mean(axis=1, keepdims=True)
        self.std = data.std(axis=1, keepdims=True)
        # Drop the concatenated copy; the per-machine sets keep their own data.
        del data

        self.training_set = torch.utils.data.ConcatDataset(training_sets)

    @property
    def observation_shape(self) -> tuple:
        """Return the shape of one observation: ``(1, num_mel, context)``."""
        return 1, self.num_mel, self.context

    def training_data_set(self):
        """Return the ``ConcatDataset`` of all complement training sets."""
        return self.training_set

    def validation_data_set(self):
        """Not available for complement data; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def mean_std(self):
        """Return the ``(mean, std)`` arrays computed in the constructor."""
        return self.mean, self.std