@@ -41,6 +41,9 @@ class HMM(DataModule):
41
41
Ignored if n_folds is None. Default to ``None``
42
42
num_workers (int): Number of workers for the loaders. Default to 0
43
43
seed (int): For the random split. Default to 42
44
+ train_size (int): Number of training samples to generate
45
+ test_size (int): Number of test samples to generate
46
+ signal_length (int): Length of the signal to generate
44
47
45
48
References:
46
49
`Explaining Time Series Predictions with Dynamic Masks <https://arxiv.org/abs/2106.05303>`_
@@ -73,6 +76,9 @@ def __init__(
73
76
fold : int = None ,
74
77
num_workers : int = 0 ,
75
78
seed : int = 42 ,
79
+ train_size : int = 800 ,
80
+ test_size : int = 200 ,
81
+ signal_length : int = 200 ,
76
82
):
77
83
super ().__init__ (
78
84
data_dir = data_dir ,
@@ -92,6 +98,10 @@ def __init__(
92
98
self .scale = scale or [[0.1 , 1.6 , 0.5 ], [- 0.1 , - 0.4 , - 1.5 ]]
93
99
self .p0 = p0 or [0.5 ]
94
100
101
+ self .train_size = train_size
102
+ self .test_size = test_size
103
+ self .signal_length = signal_length
104
+
95
105
def init_dist (self ):
96
106
# Covariance matrix is constant across states but distribution
97
107
# means change based on the state value
@@ -126,19 +136,23 @@ def next_state(previous_state, t):
126
136
next_state = np .random .binomial (1 , params )
127
137
return next_state
128
138
139
+ def get_base_file_path (self , split ):
140
+ return os .path .join (
141
+ self .data_dir ,
142
+ (f"{ split } _{ self .train_size } _{ self .test_size } _{ self .signal_length } _{ self .n_signal } _" +
143
+ f"{ self .train } _{ self .seed } _" )
144
+ )
145
+
129
146
def download (
130
147
self ,
131
- train_size : int = 800 ,
132
- test_size : int = 200 ,
133
- signal_length : int = 200 ,
134
148
split : str = "train" ,
135
149
):
136
- file = os . path . join ( self .data_dir , f" { split } _" )
150
+ base_file_path = self .get_base_file_path ( split )
137
151
138
152
if split == "train" :
139
- count = train_size
153
+ count = self . train_size
140
154
elif split == "test" :
141
- count = test_size
155
+ count = self . test_size
142
156
else :
143
157
raise NotImplementedError
144
158
@@ -159,7 +173,7 @@ def download(
159
173
previous = np .random .binomial (1 , self .p0 )[0 ]
160
174
delta_state = 0
161
175
state_n = None
162
- for i in range (signal_length ):
176
+ for i in range (self . signal_length ):
163
177
next = self .next_state (previous , delta_state )
164
178
state_n = next
165
179
@@ -197,37 +211,23 @@ def download(
197
211
all_states .append (states )
198
212
label_logits .append (y_logits )
199
213
200
- with open (
201
- os .path .join (self .data_dir , file + "features.npz" ), "wb"
202
- ) as fp :
214
+ with open (base_file_path + "features.npz" , "wb" ) as fp :
203
215
pkl .dump (obj = features , file = fp )
204
- with open (
205
- os .path .join (self .data_dir , file + "labels.npz" ), "wb"
206
- ) as fp :
216
+ with open (base_file_path + "labels.npz" , "wb" ) as fp :
207
217
pkl .dump (obj = labels , file = fp )
208
- with open (
209
- os .path .join (self .data_dir , file + "importance.npz" ), "wb"
210
- ) as fp :
218
+ with open (base_file_path + "importance.npz" , "wb" ) as fp :
211
219
pkl .dump (obj = importance_score , file = fp )
212
- with open (
213
- os .path .join (self .data_dir , file + "states.npz" ), "wb"
214
- ) as fp :
220
+ with open (base_file_path + "states.npz" , "wb" ) as fp :
215
221
pkl .dump (obj = all_states , file = fp )
216
- with open (
217
- os .path .join (self .data_dir , file + "labels_logits.npz" ), "wb"
218
- ) as fp :
222
+ with open (base_file_path + "labels_logits.npz" , "wb" ) as fp :
219
223
pkl .dump (obj = label_logits , file = fp )
220
224
221
225
def preprocess (self , split : str = "train" ) -> dict :
222
- file = os . path . join ( self .data_dir , f" { split } _" )
226
+ base_file_path = self .get_base_file_path ( split )
223
227
224
- with open (
225
- os .path .join (self .data_dir , file + "features.npz" ), "rb"
226
- ) as fp :
228
+ with open (base_file_path + "features.npz" , "rb" ) as fp :
227
229
features = np .stack (pkl .load (file = fp ))
228
- with open (
229
- os .path .join (self .data_dir , file + "labels.npz" ), "rb"
230
- ) as fp :
230
+ with open (base_file_path + "labels.npz" , "rb" ) as fp :
231
231
labels = np .stack (pkl .load (file = fp ))
232
232
233
233
return {
@@ -237,28 +237,20 @@ def preprocess(self, split: str = "train") -> dict:
237
237
238
238
def prepare_data (self ):
239
239
""""""
240
- if not os .path .exists (
241
- os .path .join (self .data_dir , "train_features.npz" )
242
- ):
240
+ if not os .path .exists (self .get_base_file_path ("train" ) + "features.npz" ):
243
241
self .download (split = "train" )
244
- if not os .path .exists (
245
- os .path .join (self .data_dir , "test_features.npz" )
246
- ):
242
+ if not os .path .exists (self .get_base_file_path ("test" ) + "features.npz" ):
247
243
self .download (split = "test" )
248
244
249
245
def true_saliency (self , split : str = "train" ) -> th .Tensor :
250
- file = os . path . join ( self .data_dir , f" { split } _" )
246
+ base_file_path = self .get_base_file_path ( split )
251
247
252
- with open (
253
- os .path .join (self .data_dir , file + "features.npz" ), "rb"
254
- ) as fp :
248
+ with open (base_file_path + "features.npz" , "rb" ) as fp :
255
249
features = np .stack (pkl .load (file = fp ))
256
250
257
251
# Load the true states that define the truly salient features
258
252
# and define A as in Section 3.2:
259
- with open (
260
- os .path .join (self .data_dir , file + "states.npz" ), "rb"
261
- ) as fp :
253
+ with open (base_file_path + "states.npz" , "rb" ) as fp :
262
254
true_states = np .stack (pkl .load (file = fp ))
263
255
true_states += 1
264
256
0 commit comments