Skip to content

Commit ca26e51

Browse files
authored
Rician Noise Transform (#2066)
* Added RicianNoise transform Signed-off-by: Lyndon Boone <[email protected]> * Use ensure_tuple_rep in channel-wise RandRicianNoise transform Signed-off-by: Lyndon Boone <[email protected]> * Added RandRicianNoised transform Signed-off-by: Lyndon Boone <[email protected]> * Autofixed coding style errors Signed-off-by: Lyndon Boone <[email protected]> * Added paper reference for RandRicianNoise in docstring Signed-off-by: Lyndon Boone <[email protected]> * Added unit test for RandRicianNoise transform Signed-off-by: Lyndon Boone <[email protected]> * Added unit test for RandRicianNoised transform Signed-off-by: Lyndon Boone <[email protected]> * Fixed mypy typing issues Signed-off-by: Lyndon Boone <[email protected]>
1 parent 35cd9df commit ca26e51

File tree

5 files changed

+266
-1
lines changed

5 files changed

+266
-1
lines changed

monai/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
RandGaussianSharpen,
7979
RandGaussianSmooth,
8080
RandHistogramShift,
81+
RandRicianNoise,
8182
RandScaleIntensity,
8283
RandShiftIntensity,
8384
RandStdShiftIntensity,
@@ -123,6 +124,9 @@
123124
RandHistogramShiftd,
124125
RandHistogramShiftD,
125126
RandHistogramShiftDict,
127+
RandRicianNoised,
128+
RandRicianNoiseD,
129+
RandRicianNoiseDict,
126130
RandScaleIntensityd,
127131
RandScaleIntensityD,
128132
RandScaleIntensityDict,

monai/transforms/intensity/array.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,17 @@
2424
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
2525
from monai.transforms.transform import RandomizableTransform, Transform
2626
from monai.transforms.utils import rescale_array
27-
from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size
27+
from monai.utils import (
28+
PT_BEFORE_1_7,
29+
InvalidPyTorchVersionError,
30+
dtype_torch_to_numpy,
31+
ensure_tuple_rep,
32+
ensure_tuple_size,
33+
)
2834

2935
__all__ = [
3036
"RandGaussianNoise",
37+
"RandRicianNoise",
3138
"ShiftIntensity",
3239
"RandShiftIntensity",
3340
"StdShiftIntensity",
@@ -85,6 +92,82 @@ def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor,
8592
return img + self._noise.astype(dtype)
8693

8794

95+
class RandRicianNoise(RandomizableTransform):
96+
"""
97+
Add Rician noise to image.
98+
Rician noise in MRI is the result of performing a magnitude operation on complex
99+
data with Gaussian noise of the same variance in both channels, as described in `Noise in Magnitude Magnetic Resonance Images
100+
<https://doi.org/10.1002/cmr.a.20124>`_. This transform is adapted from
101+
`DIPY<https://github.com/dipy/dipy>`_. See also: `The rician distribution of noisy mri data
102+
<https://doi.org/10.1002/mrm.1910340618>`_.
103+
104+
Args:
105+
prob: Probability to add Rician noise.
106+
mean: Mean or "centre" of the Gaussian distributions sampled to make up
107+
the Rician noise.
108+
std: Standard deviation (spread) of the Gaussian distributions sampled
109+
to make up the Rician noise.
110+
channel_wise: If True, treats each channel of the image separately.
111+
relative: If True, the spread of the sampled Gaussian distributions will
112+
be std times the standard deviation of the image or channel's intensity
113+
histogram.
114+
sample_std: If True, sample the spread of the Gaussian distributions
115+
uniformly from 0 to std.
116+
"""
117+
118+
def __init__(
119+
self,
120+
prob: float = 0.1,
121+
mean: Union[Sequence[float], float] = 0.0,
122+
std: Union[Sequence[float], float] = 1.0,
123+
channel_wise: bool = False,
124+
relative: bool = False,
125+
sample_std: bool = True,
126+
) -> None:
127+
RandomizableTransform.__init__(self, prob)
128+
self.prob = prob
129+
self.mean = mean
130+
self.std = std
131+
self.channel_wise = channel_wise
132+
self.relative = relative
133+
self.sample_std = sample_std
134+
self._noise1 = None
135+
self._noise2 = None
136+
137+
def _add_noise(self, img: Union[torch.Tensor, np.ndarray], mean: float, std: float):
138+
im_shape = img.shape
139+
_std = self.R.uniform(0, std) if self.sample_std else std
140+
self._noise1 = self.R.normal(mean, _std, size=im_shape)
141+
self._noise2 = self.R.normal(mean, _std, size=im_shape)
142+
if self._noise1 is None or self._noise2 is None:
143+
raise AssertionError
144+
dtype = dtype_torch_to_numpy(img.dtype) if isinstance(img, torch.Tensor) else img.dtype
145+
return np.sqrt((img + self._noise1.astype(dtype)) ** 2 + self._noise2.astype(dtype) ** 2)
146+
147+
def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
148+
"""
149+
Apply the transform to `img`.
150+
"""
151+
super().randomize(None)
152+
if not self._do_transform:
153+
return img
154+
if self.channel_wise:
155+
_mean = ensure_tuple_rep(self.mean, len(img))
156+
_std = ensure_tuple_rep(self.std, len(img))
157+
for i, d in enumerate(img):
158+
img[i] = self._add_noise(d, mean=_mean[i], std=_std[i] * d.std() if self.relative else _std[i])
159+
else:
160+
if not isinstance(self.mean, (int, float)):
161+
raise AssertionError("If channel_wise is False, mean must be a float or int number.")
162+
if not isinstance(self.std, (int, float)):
163+
raise AssertionError("If channel_wise is False, std must be a float or int number.")
164+
std = self.std * img.std() if self.relative else self.std
165+
if not isinstance(std, (int, float)):
166+
raise AssertionError
167+
img = self._add_noise(img, mean=self.mean, std=std)
168+
return img
169+
170+
88171
class ShiftIntensity(Transform):
89172
"""
90173
Shift intensity uniformly for the entire image with specified `offset`.

monai/transforms/intensity/dictionary.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
MaskIntensity,
3030
NormalizeIntensity,
3131
RandBiasField,
32+
RandRicianNoise,
3233
ScaleIntensity,
3334
ScaleIntensityRange,
3435
ScaleIntensityRangePercentiles,
@@ -41,6 +42,7 @@
4142

4243
__all__ = [
4344
"RandGaussianNoised",
45+
"RandRicianNoised",
4446
"ShiftIntensityd",
4547
"RandShiftIntensityd",
4648
"ScaleIntensityd",
@@ -152,6 +154,65 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
152154
return d
153155

154156

157+
class RandRicianNoised(RandomizableTransform, MapTransform):
158+
"""
159+
Dictionary-based version :py:class:`monai.transforms.RandRicianNoise`.
160+
Add Rician noise to image. This transform assumes all the expected fields have same shape.
161+
162+
Args:
163+
keys: Keys of the corresponding items to be transformed.
164+
See also: :py:class:`monai.transforms.compose.MapTransform`
165+
global_prob: Probability to add Rician noise to the dictionary.
166+
prob: Probability to add Rician noise to each item in the dictionary,
167+
once asserted that noise will be added to the dictionary at all.
168+
mean: Mean or "centre" of the Gaussian distributions sampled to make up
169+
the Rician noise.
170+
std: Standard deviation (spread) of the Gaussian distributions sampled
171+
to make up the Rician noise.
172+
channel_wise: If True, treats each channel of the image separately.
173+
relative: If True, the spread of the sampled Gaussian distributions will
174+
be std times the standard deviation of the image or channel's intensity
175+
histogram.
176+
sample_std: If True, sample the spread of the Gaussian distributions
177+
uniformly from 0 to std.
178+
allow_missing_keys: Don't raise exception if key is missing.
179+
"""
180+
181+
def __init__(
182+
self,
183+
keys: KeysCollection,
184+
global_prob: float = 0.1,
185+
prob: float = 1.0,
186+
mean: Union[Sequence[float], float] = 0.0,
187+
std: Union[Sequence[float], float] = 1.0,
188+
channel_wise: bool = False,
189+
relative: bool = False,
190+
sample_std: bool = True,
191+
allow_missing_keys: bool = False,
192+
) -> None:
193+
MapTransform.__init__(self, keys, allow_missing_keys)
194+
RandomizableTransform.__init__(self, global_prob)
195+
self.rand_rician_noise = RandRicianNoise(
196+
prob,
197+
mean,
198+
std,
199+
channel_wise,
200+
relative,
201+
sample_std,
202+
)
203+
204+
def __call__(
205+
self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]]
206+
) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]:
207+
d = dict(data)
208+
super().randomize(None)
209+
if not self._do_transform:
210+
return d
211+
for key in self.key_iterator(d):
212+
d[key] = self.rand_rician_noise(d[key])
213+
return d
214+
215+
155216
class ShiftIntensityd(MapTransform):
156217
"""
157218
Dictionary-based wrapper of :py:class:`monai.transforms.ShiftIntensity`.
@@ -958,6 +1019,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
9581019

9591020

9601021
RandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised
1022+
RandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised
9611023
ShiftIntensityD = ShiftIntensityDict = ShiftIntensityd
9621024
RandShiftIntensityD = RandShiftIntensityDict = RandShiftIntensityd
9631025
StdShiftIntensityD = StdShiftIntensityDict = StdShiftIntensityd

tests/test_rand_rician_noise.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
from parameterized import parameterized
16+
17+
from monai.transforms import RandRicianNoise
18+
from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D
19+
20+
21+
class TestRandRicianNoise(NumpyImageTestCase2D):
22+
@parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)])
23+
def test_correct_results(self, _, mean, std):
24+
seed = 0
25+
rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std)
26+
rician_fn.set_random_state(seed)
27+
noised = rician_fn(self.imt)
28+
np.random.seed(seed)
29+
np.random.random()
30+
_std = np.random.uniform(0, std)
31+
expected = np.sqrt(
32+
(self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2
33+
+ np.random.normal(mean, _std, size=self.imt.shape) ** 2
34+
)
35+
np.testing.assert_allclose(expected, noised, atol=1e-5)
36+
37+
38+
class TestRandRicianNoiseTorch(TorchImageTestCase2D):
39+
@parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)])
40+
def test_correct_results(self, _, mean, std):
41+
seed = 0
42+
rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std)
43+
rician_fn.set_random_state(seed)
44+
noised = rician_fn(self.imt)
45+
np.random.seed(seed)
46+
np.random.random()
47+
_std = np.random.uniform(0, std)
48+
expected = np.sqrt(
49+
(self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2
50+
+ np.random.normal(mean, _std, size=self.imt.shape) ** 2
51+
)
52+
np.testing.assert_allclose(expected, noised, atol=1e-5)
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main()

tests/test_rand_rician_noised.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
from parameterized import parameterized
16+
17+
from monai.transforms import RandRicianNoised
18+
from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D
19+
20+
TEST_CASE_0 = ["test_zero_mean", ["img1", "img2"], 0, 0.1]
21+
TEST_CASE_1 = ["test_non_zero_mean", ["img1", "img2"], 1, 0.5]
22+
TEST_CASES = [TEST_CASE_0, TEST_CASE_1]
23+
24+
seed = 0
25+
26+
27+
def test_numpy_or_torch(keys, mean, std, imt):
28+
rician_fn = RandRicianNoised(keys=keys, global_prob=1.0, prob=1.0, mean=mean, std=std)
29+
rician_fn.set_random_state(seed)
30+
rician_fn.rand_rician_noise.set_random_state(seed)
31+
noised = rician_fn({k: imt for k in keys})
32+
np.random.seed(seed)
33+
np.random.random()
34+
np.random.seed(seed)
35+
for k in keys:
36+
np.random.random()
37+
_std = np.random.uniform(0, std)
38+
expected = np.sqrt(
39+
(imt + np.random.normal(mean, _std, size=imt.shape)) ** 2
40+
+ np.random.normal(mean, _std, size=imt.shape) ** 2
41+
)
42+
np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5)
43+
44+
45+
# Test with numpy
46+
class TestRandRicianNoisedNumpy(NumpyImageTestCase2D):
47+
@parameterized.expand(TEST_CASES)
48+
def test_correct_results(self, _, keys, mean, std):
49+
test_numpy_or_torch(keys, mean, std, self.imt)
50+
51+
52+
# Test with torch
53+
class TestRandRicianNoisedTorch(TorchImageTestCase2D):
54+
@parameterized.expand(TEST_CASES)
55+
def test_correct_results(self, _, keys, mean, std):
56+
test_numpy_or_torch(keys, mean, std, self.imt)
57+
58+
59+
if __name__ == "__main__":
60+
unittest.main()

0 commit comments

Comments
 (0)