Commit cfe87ad
Fix HMM data caching issue
The Hidden Markov Model data generator created the required data on its first invocation, cached the resulting files, and loaded from them on subsequent calls. However, the cache file names did not include any of the parameters used for data generation. This was a problem because the HMM unit test generates far fewer samples with shorter signals, while the main HMM experiment uses more samples and longer signals. A user who ran the unit tests before the main code would therefore unintentionally train on the smaller dataset. This caused the incorrect values reported for the HMM experiment in the original version of the published paper, which were later corrected. The HMM module now creates a separate cache file for each configuration, so this error can no longer occur. Currently the file name includes only the numeric settings, which is enough to prevent the error; in the future it could be extended to cover non-numeric settings as well.
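To make the failure mode concrete, below is a minimal, self-contained sketch of the two caching schemes. It is illustrative only: `cache_path_old`, `cache_path_new`, and the stand-in generator are hypothetical simplifications of the real HMM data module, which stores several files per split.

import os
import pickle as pkl

import numpy as np


def cache_path_old(data_dir: str, split: str) -> str:
    # Old scheme: the cache key names only the split, so a 10-sample
    # test fixture and a 200-sample experiment collide on one file.
    return os.path.join(data_dir, f"{split}_features.npz")


def cache_path_new(data_dir: str, split: str, count: int,
                   signal_length: int, seed: int) -> str:
    # New scheme: the numeric generation settings are baked into the
    # name, so different configurations can never shadow each other.
    return os.path.join(
        data_dir, f"{split}_{count}_{signal_length}_{seed}_features.npz"
    )


def load_or_generate(path: str, count: int, signal_length: int, seed: int):
    # Under the old scheme this happily returns stale data generated
    # with different settings; under the new scheme a settings change
    # yields a fresh path and triggers regeneration.
    if os.path.exists(path):
        with open(path, "rb") as fp:
            return pkl.load(fp)
    rng = np.random.default_rng(seed)
    data = rng.normal(size=(count, signal_length, 3))  # stand-in generator
    with open(path, "wb") as fp:
        pkl.dump(data, fp)
    return data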

File tree

3 files changed: +38 -43 lines changed

experiments/hmm/main.py

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ def main(
 
     # Load data
     hmm = HMM(n_folds=5, fold=fold, seed=seed)
+    hmm.prepare_data()
 
     print(f"Training classifier..")
 
tests/datasets/test_hmm.py

Lines changed: 3 additions & 1 deletion
@@ -108,8 +108,10 @@ def test_hmm(
         fold=fold,
         num_workers=num_workers,
         seed=seed,
+        test_size=10,
+        signal_length=20
     )
-    hmm.download(split="test", test_size=10, signal_length=20)
+    hmm.download(split="test")
     x_test = hmm.preprocess(split="test")["x"]
     y_test = hmm.preprocess(split="test")["y"]
     assert tuple(x_test.shape) == (10, 20, 3)

tint/datasets/hmm.py

Lines changed: 34 additions & 42 deletions
@@ -41,6 +41,9 @@ class HMM(DataModule):
             Ignored if n_folds is None. Default to ``None``
         num_workers (int): Number of workers for the loaders. Default to 0
         seed (int): For the random split. Default to 42
+        train_size (int): Number of training samples to generate
+        test_size (int): Number of test samples to generate
+        signal_length (int): Length of the signal to generate
 
     References:
         `Explaining Time Series Predictions with Dynamic Masks <https://arxiv.org/abs/2106.05303>`_
@@ -73,6 +76,9 @@ def __init__(
         fold: int = None,
         num_workers: int = 0,
         seed: int = 42,
+        train_size: int = 800,
+        test_size: int = 200,
+        signal_length: int = 200,
     ):
         super().__init__(
             data_dir=data_dir,
@@ -92,6 +98,10 @@ def __init__(
         self.scale = scale or [[0.1, 1.6, 0.5], [-0.1, -0.4, -1.5]]
         self.p0 = p0 or [0.5]
 
+        self.train_size = train_size
+        self.test_size = test_size
+        self.signal_length = signal_length
+
     def init_dist(self):
         # Covariance matrix is constant across states but distribution
         # means change based on the state value
@@ -126,19 +136,23 @@ def next_state(previous_state, t):
             next_state = np.random.binomial(1, params)
             return next_state
 
+    def get_base_file_path(self, split):
+        return os.path.join(
+            self.data_dir,
+            (f"{split}_{self.train_size}_{self.test_size}_{self.signal_length}_{self.n_signal}_" +
+             f"{self.train}_{self.seed}_")
+        )
+
     def download(
         self,
-        train_size: int = 800,
-        test_size: int = 200,
-        signal_length: int = 200,
         split: str = "train",
     ):
-        file = os.path.join(self.data_dir, f"{split}_")
+        base_file_path = self.get_base_file_path(split)
 
         if split == "train":
-            count = train_size
+            count = self.train_size
         elif split == "test":
-            count = test_size
+            count = self.test_size
         else:
            raise NotImplementedError
 
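With this change, the unit-test settings and the experiment defaults resolve to different cache prefixes. A hypothetical illustration (assuming the remaining constructor arguments keep their defaults; the values of `n_signal` and `train` are elided as placeholders):

hmm_test = HMM(seed=42, test_size=10, signal_length=20)
hmm_test.get_base_file_path("test")
# -> ".../test_800_10_20_<n_signal>_<train>_42_"

hmm_exp = HMM(seed=42)  # train_size=800, test_size=200, signal_length=200
hmm_exp.get_base_file_path("test")
# -> ".../test_800_200_200_<n_signal>_<train>_42_"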
@@ -159,7 +173,7 @@ def download
             previous = np.random.binomial(1, self.p0)[0]
             delta_state = 0
             state_n = None
-            for i in range(signal_length):
+            for i in range(self.signal_length):
                 next = self.next_state(previous, delta_state)
                 state_n = next
 
@@ -197,37 +211,23 @@ def download
             all_states.append(states)
             label_logits.append(y_logits)
 
-        with open(
-            os.path.join(self.data_dir, file + "features.npz"), "wb"
-        ) as fp:
+        with open(base_file_path + "features.npz", "wb") as fp:
             pkl.dump(obj=features, file=fp)
-        with open(
-            os.path.join(self.data_dir, file + "labels.npz"), "wb"
-        ) as fp:
+        with open(base_file_path + "labels.npz", "wb") as fp:
             pkl.dump(obj=labels, file=fp)
-        with open(
-            os.path.join(self.data_dir, file + "importance.npz"), "wb"
-        ) as fp:
+        with open(base_file_path + "importance.npz", "wb") as fp:
             pkl.dump(obj=importance_score, file=fp)
-        with open(
-            os.path.join(self.data_dir, file + "states.npz"), "wb"
-        ) as fp:
+        with open(base_file_path + "states.npz", "wb") as fp:
             pkl.dump(obj=all_states, file=fp)
-        with open(
-            os.path.join(self.data_dir, file + "labels_logits.npz"), "wb"
-        ) as fp:
+        with open(base_file_path + "labels_logits.npz", "wb") as fp:
             pkl.dump(obj=label_logits, file=fp)
 
     def preprocess(self, split: str = "train") -> dict:
-        file = os.path.join(self.data_dir, f"{split}_")
+        base_file_path = self.get_base_file_path(split)
 
-        with open(
-            os.path.join(self.data_dir, file + "features.npz"), "rb"
-        ) as fp:
+        with open(base_file_path + "features.npz", "rb") as fp:
             features = np.stack(pkl.load(file=fp))
-        with open(
-            os.path.join(self.data_dir, file + "labels.npz"), "rb"
-        ) as fp:
+        with open(base_file_path + "labels.npz", "rb") as fp:
             labels = np.stack(pkl.load(file=fp))
 
         return {
@@ -237,28 +237,20 @@ def preprocess(self, split: str = "train") -> dict:
 
     def prepare_data(self):
         """"""
-        if not os.path.exists(
-            os.path.join(self.data_dir, "train_features.npz")
-        ):
+        if not os.path.exists(self.get_base_file_path("train") + "features.npz"):
            self.download(split="train")
-        if not os.path.exists(
-            os.path.join(self.data_dir, "test_features.npz")
-        ):
+        if not os.path.exists(self.get_base_file_path("test") + "features.npz"):
            self.download(split="test")
 
     def true_saliency(self, split: str = "train") -> th.Tensor:
-        file = os.path.join(self.data_dir, f"{split}_")
+        base_file_path = self.get_base_file_path(split)
 
-        with open(
-            os.path.join(self.data_dir, file + "features.npz"), "rb"
-        ) as fp:
+        with open(base_file_path + "features.npz", "rb") as fp:
             features = np.stack(pkl.load(file=fp))
 
         # Load the true states that define the truly salient features
         # and define A as in Section 3.2:
-        with open(
-            os.path.join(self.data_dir, file + "states.npz"), "rb"
-        ) as fp:
+        with open(base_file_path + "states.npz", "rb") as fp:
             true_states = np.stack(pkl.load(file=fp))
         true_states += 1
 
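As a consequence, callers such as experiments/hmm/main.py only need to construct the HMM with the desired settings and call prepare_data(); a configuration whose cache file is missing is regenerated automatically instead of silently reusing data from a different configuration. A hypothetical usage sketch (argument names taken from the diffs above; fold=0 and the output shape are illustrative, inferred from the defaults and the test assertion):

hmm = HMM(n_folds=5, fold=0, seed=42)  # defaults: 800/200 samples, length 200
hmm.prepare_data()                     # generates or loads the matching cache
x_test = hmm.preprocess(split="test")["x"]  # presumably shape (200, 200, 3)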