Skip to content

feat: support aten.isnan converter #2711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,23 @@ def aten_ops_isinf(
)


@dynamo_tensorrt_converter(torch.ops.aten.isnan.default)
def aten_ops_isnan(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.unary.isnan(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
)


@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
def aten_ops_add(
Expand Down
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,23 @@ def scalar_tensor(
identity_layer = ctx.net.add_identity(tensor)
set_layer_name(identity_layer, target, name, source_ir)
return identity_layer.get_output(0)


def isnan(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
) -> TRTTensor:
# False for NaN elements since NaN is not equal to anything, including itself.
equality_result = impl.elementwise.eq(
ctx, target, source_ir, f"{name}_eq_nan", input, input
)

# Invert equality_result to get a mask where NaN values are marked as True.
nan_values_mask = logical_not(
ctx, target, source_ir, f"{name}_logical_not", equality_result
)

return nan_values_mask
82 changes: 82 additions & 0 deletions tests/py/dynamo/conversion/test_isnan_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestIsNanConverter(DispatchTestCase):
@parameterized.expand(
[
(
torch.tensor(
[
1.23,
float("nan"),
-4.56,
float("inf"),
float("-inf"),
-100.0,
float("nan"),
0.13,
-0.13,
3.14159265,
]
),
),
]
)
def test_isnan_float(self, data):
class isnan(nn.Module):
def forward(self, input):
return torch.ops.aten.isnan.default(input)

inputs = [data]
self.run_test(
isnan(),
inputs,
output_dtypes=[torch.bool],
)

@parameterized.expand(
[
(torch.full((2, 2), float("nan"), dtype=torch.float32),),
(torch.full((3, 10, 5), float("nan"), dtype=torch.float32),),
(torch.randn((5, 10, 5), dtype=torch.float32),),
]
)
def test_isnan_dim(self, data):
class isnan(nn.Module):
def forward(self, input):
return torch.ops.aten.isnan.default(input)

inputs = [data]
self.run_test(
isnan(),
inputs,
output_dtypes=[torch.bool],
)

@parameterized.expand(
[
((10,), torch.int, 0, 5),
((1, 20), torch.int32, -10, 10),
((2, 3, 4), torch.int, -5, 5),
]
)
def test_isnan_int(self, input_shape, dtype, low, high):
class isnan(nn.Module):
def forward(self, input):
return torch.ops.aten.isnan.default(input)

inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
self.run_test(
isnan(),
inputs,
output_dtypes=[torch.bool],
)


if __name__ == "__main__":
run_tests()