
Commit 4ba91bf

make mypy more strict for prototype datasets (#4513)
* make mypy more strict for prototype datasets
* fix code format
* apply strictness only to datasets
* fix more mypy issues
* cleanup
* fix mnist annotations
* refactor celeba
* warn on redundant casts
* remove redundant cast
* simplify annotation
* fix import
1 parent 9407b45 commit 4ba91bf

File tree

16 files changed: +146 −119 lines changed


mypy.ini

Lines changed: 17 additions & 0 deletions
@@ -4,6 +4,23 @@ files = torchvision
 show_error_codes = True
 pretty = True
 allow_redefinition = True
+warn_redundant_casts = True
+
+[mypy-torchvision.prototype.datasets.*]
+
+; untyped definitions and calls
+disallow_untyped_defs = True
+
+; None and Optional handling
+no_implicit_optional = True
+
+; warnings
+warn_unused_ignores = True
+warn_return_any = True
+warn_unreachable = True
+
+; miscellaneous strictness flags
+allow_redefinition = True
 
 [mypy-torchvision.io._video_opt.*]
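For context, a minimal sketch of what the new per-module section enforces (hypothetical functions, not from this commit): under disallow_untyped_defs, any function in torchvision.prototype.datasets must be fully annotated.

from typing import List


def make_names(count):  # error: function is missing a type annotation  [no-untyped-def]
    return ["sample"] * count


def make_names_typed(count: int) -> List[str]:  # accepted: fully annotated
    return ["sample"] * count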

torchvision/datasets/usps.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, cast, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple
 
 import numpy as np
 from PIL import Image
@@ -63,7 +63,7 @@ def __init__(
             raw_data = [line.decode().split() for line in fp.readlines()]
             tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
             imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
-            imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
+            imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
             targets = [int(d[0]) - 1 for d in raw_data]
 
             self.data = imgs
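The removed cast is the kind of thing the newly enabled warn_redundant_casts reports: np.asarray already yields an ndarray, so casting it again adds no information. A minimal sketch (whether mypy actually flags it depends on the numpy stubs in use):

from typing import cast

import numpy as np

imgs = np.asarray([[0.5, 1.0]], dtype=np.float32)  # already an ndarray
imgs = cast(np.ndarray, imgs)  # [redundant-cast]: cast target matches the inferred type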

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 8 additions & 5 deletions
@@ -82,7 +82,10 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
         return category, id
 
     def _collate_and_decode_sample(
-        self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
+        self,
+        data: Tuple[Tuple[str, str], Tuple[str, io.IOBase], Tuple[str, io.IOBase]],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
         key, image_data, ann_data = data
         category, _ = key
@@ -117,11 +120,11 @@ def _make_datapipe(
         images_dp, anns_dp = resource_dps
 
         images_dp = TarArchiveReader(images_dp)
-        images_dp: IterDataPipe = Filter(images_dp, self._is_not_background_image)
+        images_dp = Filter(images_dp, self._is_not_background_image)
         images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
 
         anns_dp = TarArchiveReader(anns_dp)
-        anns_dp: IterDataPipe = Filter(anns_dp, self._is_ann)
+        anns_dp = Filter(anns_dp, self._is_ann)
 
         dp = KeyZipper(
             images_dp,
@@ -136,7 +139,7 @@
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
         dp = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, self._is_not_background_image)
+        dp = Filter(dp, self._is_not_background_image)
         return sorted({pathlib.Path(path).parent.name for path, _ in dp})
 
 
@@ -185,7 +188,7 @@ def _make_datapipe(
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
        dp = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, self._is_not_rogue_file)
+        dp = Filter(dp, self._is_not_rogue_file)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
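Dropping the repeated dp: IterDataPipe = ... annotations relies on allow_redefinition = True from the config above: mypy then permits rebinding a name to a different type within the same block and nesting level, so pipeline-style reassignment needs no annotations. A minimal sketch of that behavior (hypothetical function):

def summarize(raw: bytes) -> int:
    data = raw.decode()     # data is inferred as str
    data = data.split(",")  # rebinding to List[str] is fine with allow_redefinition
    return len(data)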

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 31 additions & 33 deletions
@@ -1,6 +1,6 @@
 import csv
 import io
-from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence
 
 import torch
 from torchdata.datapipes.iter import (
@@ -23,37 +23,38 @@
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor
 
 
-class CelebACSVParser(IterDataPipe):
+csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
+
+
+class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
     def __init__(
         self,
-        datapipe,
+        datapipe: IterDataPipe[Tuple[Any, io.IOBase]],
         *,
-        has_header,
-    ):
+        fieldnames: Optional[Sequence[str]] = None,
+    ) -> None:
         self.datapipe = datapipe
-        self.has_header = has_header
-        self._fmtparams = dict(delimiter=" ", skipinitialspace=True)
+        self.fieldnames = fieldnames
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
         for _, file in self.datapipe:
             file = (line.decode() for line in file)
 
-            if self.has_header:
+            if self.fieldnames:
+                fieldnames = self.fieldnames
+            else:
                 # The first row is skipped, because it only contains the number of samples
                 next(file)
 
-                # Empty field names are filtered out, because some files have an extr white space after the header
+                # Empty field names are filtered out, because some files have an extra white space after the header
                 # line, which is recognized as extra column
-                fieldnames = [name for name in next(csv.reader([next(file)], **self._fmtparams)) if name]
+                fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
                 # Some files do not include a label for the image ID column
                 if fieldnames[0] != "image_id":
                     fieldnames.insert(0, "image_id")
 
-                for line in csv.DictReader(file, fieldnames=fieldnames, **self._fmtparams):
-                    yield line.pop("image_id"), line
-            else:
-                for line in csv.reader(file, **self._fmtparams):
-                    yield line[0], line[1:]
+            for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
+                yield line.pop("image_id"), line
 
 
 class CelebA(Dataset):
@@ -104,13 +105,10 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         "2": "test",
     }
 
-    def _filter_split(self, data: Tuple[str, str], *, split):
-        _, split_id = data
-        return self._SPLIT_ID_TO_NAME[split_id[0]] == split
+    def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool:
+        return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split
 
-    def _collate_anns(
-        self, data: Tuple[Tuple[str, Union[List[str], Mapping[str, str]]], ...]
-    ) -> Tuple[str, Dict[str, Union[List[str], Mapping[str, str]]]]:
+    def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Dict[str, str]]]:
         (image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data
         return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks)
@@ -127,7 +125,7 @@ def _collate_and_decode_sample(
 
         image = decoder(buffer) if decoder else buffer
 
-        identity = torch.tensor(int(ann["identity"][0]))
+        identity = int(ann["identity"]["identity"])
         attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
         bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
         landmarks = {
@@ -153,24 +151,24 @@ def _make_datapipe(
     ) -> IterDataPipe[Dict[str, Any]]:
         splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps
 
-        splits_dp = CelebACSVParser(splits_dp, has_header=False)
-        splits_dp: IterDataPipe = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
+        splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
+        splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
         splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
 
         images_dp = ZipArchiveReader(images_dp)
 
-        anns_dp: IterDataPipe = Zipper(
+        anns_dp = Zipper(
             *[
-                CelebACSVParser(dp, has_header=has_header)
-                for dp, has_header in (
-                    (identities_dp, False),
-                    (attributes_dp, True),
-                    (bboxes_dp, True),
-                    (landmarks_dp, True),
+                CelebACSVParser(dp, fieldnames=fieldnames)
+                for dp, fieldnames in (
+                    (identities_dp, ("image_id", "identity")),
+                    (attributes_dp, None),
+                    (bboxes_dp, None),
+                    (landmarks_dp, None),
                 )
             ]
         )
-        anns_dp: IterDataPipe = Mapper(anns_dp, self._collate_anns)
+        anns_dp = Mapper(anns_dp, self._collate_anns)
 
         dp = KeyZipper(
             splits_dp,
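Registering a named csv dialect at import time replaces the untyped per-instance _fmtparams dict and lets every reader refer to the format by name. A minimal standalone sketch of the same pattern, with made-up sample rows:

import csv
import io

# one-time registration of the whitespace-delimited CelebA format
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)

file = io.StringIO("000001.jpg 0\n000002.jpg 1\n")
for row in csv.DictReader(file, fieldnames=("image_id", "split_id"), dialect="celeba"):
    print(row.pop("image_id"), row)  # 000001.jpg {'split_id': '0'}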

torchvision/prototype/datasets/_builtin/cifar.py

Lines changed: 10 additions & 10 deletions
@@ -3,7 +3,7 @@
 import io
 import pathlib
 import pickle
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast
 
 import numpy as np
 import torch
@@ -56,7 +56,7 @@ def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
 
     def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
         _, file = data
-        return pickle.load(file, encoding="latin1")
+        return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
 
     def _collate_and_decode(
         self,
@@ -86,19 +86,19 @@ def _make_datapipe(
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
-        dp: IterDataPipe = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, functools.partial(self._is_data_file, config=config))
-        dp: IterDataPipe = Mapper(dp, self._unpickle)
+        dp = TarArchiveReader(dp)
+        dp = Filter(dp, functools.partial(self._is_data_file, config=config))
+        dp = Mapper(dp, self._unpickle)
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
 
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
         dp = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
-        dp: IterDataPipe = Mapper(dp, self._unpickle)
-        return next(iter(dp))[self._CATEGORIES_KEY]
+        dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
+        dp = Mapper(dp, self._unpickle)
+        return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])
 
 
 class Cifar10(_CifarBase):
@@ -133,9 +133,9 @@ class Cifar100(_CifarBase):
     _META_FILE_NAME = "meta"
     _CATEGORIES_KEY = "fine_label_names"
 
-    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
+    def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
         path = pathlib.Path(data[0])
-        return path.name == config.split
+        return path.name == cast(str, config.split)
 
     @property
     def info(self) -> DatasetInfo:
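typing.cast here is purely a type-checker hint: pickle.load is annotated as returning Any, and with warn_return_any enabled that Any may not leak out of an annotated function. A minimal sketch (hypothetical helper); note that cast performs no runtime validation:

import pickle
from typing import Any, Dict, cast


def unpickle(path: str) -> Dict[str, Any]:
    with open(path, "rb") as f:
        # narrows Any -> Dict[str, Any] for mypy; a no-op at runtime
        return cast(Dict[str, Any], pickle.load(f))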

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 13 additions & 10 deletions
@@ -1,7 +1,7 @@
 import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast
 
 import torch
 from torchdata.datapipes.iter import IterDataPipe, LineReader, KeyZipper, Mapper, TarArchiveReader, Filter, Shuffler
@@ -44,11 +44,11 @@ def info(self) -> DatasetInfo:
 
     @property
     def category_to_wnid(self) -> Dict[str, str]:
-        return self.info.extra.category_to_wnid
+        return cast(Dict[str, str], self.info.extra.category_to_wnid)
 
     @property
     def wnid_to_category(self) -> Dict[str, str]:
-        return self.info.extra.wnid_to_category
+        return cast(Dict[str, str], self.info.extra.wnid_to_category)
 
     def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         if config.split == "train":
@@ -152,20 +152,23 @@ def _make_datapipe(
         "n03710721": "tank suit",
     }
 
-    def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, str]]:
+    def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
         resources = self.resources(self.default_config)
         devkit_dp = resources[1].to_datapipe(root / self.name)
         devkit_dp = TarArchiveReader(devkit_dp)
         devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
 
         meta = next(iter(devkit_dp))[1]
         synsets = read_mat(meta, squeeze_me=True)["synsets"]
-        categories_and_wnids = [
-            (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
-            for _, wnid, category, _, num_children, *_ in synsets
-            # if num_children > 0, we are looking at a superclass that has no direct instance
-            if num_children == 0
-        ]
+        categories_and_wnids = cast(
+            List[Tuple[str, ...]],
+            [
+                (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
+                for _, wnid, category, _, num_children, *_ in synsets
+                # if num_children > 0, we are looking at a superclass that has no direct instance
+                if num_children == 0
+            ],
+        )
         categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
 
         return categories_and_wnids
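The widened return type uses Tuple[str, ...], a homogeneous tuple of arbitrary length, which is all mypy can guarantee for pairs built from the Any-typed rows read_mat returns; the cast then pins down the comprehension's element type. A minimal sketch of the two annotations (made-up values):

from typing import List, Tuple

pair: Tuple[str, str] = ("abacus", "n02666196")  # exactly two strings
variadic: Tuple[str, ...] = ("abacus", "n02666196", "extra")  # any length
rows: List[Tuple[str, ...]] = [pair, variadic]  # Tuple[str, str] is a subtype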
