add an example of aten2trt, fix batch norm pass #1685


Merged (2 commits) on Feb 22, 2023
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -263,7 +263,7 @@ commands:
       parameters:
         torch-build:
           type: string
-          default: "2.0.0.dev20230129+cu117"
+          default: "2.0.0.dev20230219+cu117"
         torch-build-index:
           type: string
           default: "https://download.pytorch.org/whl/nightly/cu117"
@@ -1026,7 +1026,7 @@ parameters:
   # Nightly platform config
   torch-build:
     type: string
-    default: "2.0.0.dev20230129+cu117"
+    default: "2.0.0.dev20230219+cu117"
   torch-build-index:
     type: string
     default: "https://download.pytorch.org/whl/nightly/cu117"
1 change: 1 addition & 0 deletions examples/fx/lower_example.py
@@ -188,6 +188,7 @@ def run_configuration_benchmark(
             input,
             max_batch_size=conf.batch_size,
             lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
+            explicit_batch_dimension=True,
         )
         time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
     else:
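For context, `explicit_batch_dimension=True` makes fx2trt build the TensorRT network with an explicit (rather than implicit) batch dimension, the same mode the new aten example below uses. A minimal standalone sketch of this `compile()` call outside the benchmark harness; the model, shapes, and batch size are placeholders of mine, and a CUDA device is assumed:

```python
import torch
import torchvision
from torch_tensorrt.fx import compile
from torch_tensorrt.fx.utils import LowerPrecision

model = torchvision.models.resnet18(pretrained=True).cuda().eval()
sample = [torch.rand(8, 3, 224, 224).cuda()]

lowered = compile(
    model,
    sample,
    max_batch_size=8,
    lower_precision=LowerPrecision.FP32,
    explicit_batch_dimension=True,  # build the TRT network with an explicit batch dim
)
print(lowered(*sample).shape)  # expected: torch.Size([8, 1000])
```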
196 changes: 196 additions & 0 deletions examples/fx/lower_example_aten.py
@@ -0,0 +1,196 @@
import typing as t
from copy import deepcopy
from dataclasses import dataclass, field, replace

import torch
import torchvision
from torch_tensorrt.fx import compile
from torch_tensorrt.fx.utils import LowerPrecision


"""
The purpose of this example is to demostrate the onverall flow of lowering a PyTorch model
to TensorRT conveniently with lower.py.
"""


@dataclass
class Configuration:
"""
Specify the configuration used for fx2trt lowering and benchmark.

To extend, add a new configuration field to this class, and modify the
lowering or benchmark behavior in `run_configuration_benchmark()`
correspondingly.

It automatically prints all its values thanks to being a dataclass.
"""

# number of inferences to run
batch_iter: int

# Input batch size
batch_size: int

# Friendly name of the configuration
name: str = ""

# Whether to apply TRT lowering to the model before benchmarking
trt: bool = False

# Whether to apply engine holder to the lowered model
jit: bool = False

# Whether to enable FP16 mode for TRT lowering
fp16: bool = False

# Relative tolerance for accuracy check after lowering. -1 means do not
# check accuracy.
accuracy_rtol: float = -1 # disable


@dataclass
class Result:
"""Holds and computes the benchmark results.

Holds raw essential benchmark result values like duration.
Also computes results that can be derived from the raw essential values
(QPS), in the form of auto properties.

"""

module: torch.nn.Module = field(repr=False)
input: t.Any = field(repr=False)
conf: Configuration
time_sec: float
accuracy_res: t.Optional[bool] = None

@property
def time_per_iter_ms(self) -> float:
return self.time_sec * 1.0e3

@property
def qps(self) -> float:
return self.conf.batch_size / self.time_sec

def format(self) -> str:
return (
f"== Benchmark Result for: {self.conf}\n"
f"BS: {self.conf.batch_size}, "
f"Time per iter: {self.time_per_iter_ms:.2f}ms, "
f"QPS: {self.qps:.2f}, "
f"Accuracy: {self.accuracy_res} (rtol={self.conf.accuracy_rtol})"
)


@torch.inference_mode()
def benchmark(
    model,
    inputs,
    batch_iter: int,
    batch_size: int,
) -> None:
    """
    Run fx2trt lowering and benchmark the given model according to the
    specified benchmark configuration. Prints the benchmark result for each
    configuration at the end of the run.
    """

    model = model.cuda().eval()
    inputs = [x.cuda() for x in inputs]

    # Benchmark the base configuration
    conf = Configuration(batch_iter=batch_iter, batch_size=batch_size)

    configurations = [
        # Baseline
        replace(conf, name="CUDA Eager", trt=False),
        # FP16
        replace(
            conf,
            name="TRT FP16 Eager",
            trt=True,
            jit=False,
            fp16=True,
            accuracy_rtol=1e-2,
        ),
    ]

    results = [
        run_configuration_benchmark(deepcopy(model), inputs, conf_)
        for conf_ in configurations
    ]

    for res in results:
        print(res.format())


def benchmark_torch_function(iters: int, f, *args) -> float:
"""Estimates the average time duration for a single inference call in second

If the input is batched, then the estimation is for the batches inference call.

Args:
iters: number of inference iterations to run
f: a function to perform a single inference call

Returns:
estimated average time duration in second for a single inference call
"""
with torch.inference_mode():
f(*args)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
print("== Start benchmark iterations")
with torch.inference_mode():
start_event.record()
for _ in range(iters):
f(*args)
end_event.record()
torch.cuda.synchronize()
print("== End benchmark iterations")
return (start_event.elapsed_time(end_event) * 1.0e-3) / iters


def run_configuration_benchmark(
    module,
    input,
    conf: Configuration,
) -> Result:
    """
    Runs `module` through the lowering logic and benchmarks the module
    before and after lowering.
    """
    print(f"=== Running benchmark for: {conf}", "green")
    time = -1.0

    if conf.fp16:
        module = module.half()
        input = [i.half() for i in input]

    if not conf.trt:
        # Run eager-mode benchmark
        time = benchmark_torch_function(conf.batch_iter, lambda: module(*input))
    elif not conf.jit:
        # Lower the module through the aten path, then benchmark it in eager mode
        lowered_module = compile(
            module,
            input,
            max_batch_size=conf.batch_size,
            lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
            explicit_batch_dimension=True,
            is_aten=True,
        )
        time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
    else:
        print("Lowering with JIT is not available!", "red")

    result = Result(module=module, input=input, conf=conf, time_sec=time)
    return result


if __name__ == "__main__":
    test_model = torchvision.models.resnet18(pretrained=True)
    input = [torch.rand(128, 3, 224, 224)]  # type: ignore[attr-defined]
    benchmark(test_model, input, 50, 128)
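As a sanity check on the metrics that `Result` derives: with `batch_size=128` and a measured `time_sec` of 0.04 s per batched call, `time_per_iter_ms` is 40 ms and `qps` is 128 / 0.04 = 3200. A small sketch reusing the `Configuration` and `Result` dataclasses defined in the example above; the timing number is hypothetical:

```python
# Hypothetical timing, only to illustrate the derived metrics.
conf = Configuration(batch_iter=50, batch_size=128, name="example")
res = Result(module=torch.nn.Identity(), input=None, conf=conf, time_sec=0.04)
assert res.time_per_iter_ms == 40.0  # 0.04 s per batched call -> 40 ms
assert res.qps == 3200.0             # 128 samples / 0.04 s
print(res.format())
```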
2 changes: 2 additions & 0 deletions py/setup.py
@@ -353,13 +353,15 @@ def run(self):
"torch_tensorrt.fx.passes",
"torch_tensorrt.fx.tools",
"torch_tensorrt.fx.tracer.acc_tracer",
"torch_tensorrt.fx.tracer.dispatch_tracer",
]
package_dir = {
"torch_tensorrt.fx": "torch_tensorrt/fx",
"torch_tensorrt.fx.converters": "torch_tensorrt/fx/converters",
"torch_tensorrt.fx.passes": "torch_tensorrt/fx/passes",
"torch_tensorrt.fx.tools": "torch_tensorrt/fx/tools",
"torch_tensorrt.fx.tracer.acc_tracer": "torch_tensorrt/fx/tracer/acc_tracer",
"torch_tensorrt.fx.tracer.dispatch_tracer": "torch_tensorrt/fx/tracer/dispatch_tracer",
}

with open("README.md", "r", encoding="utf-8") as fh:
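Since the `dispatch_tracer` subpackage backs the aten tracing path, a quick way to confirm it actually ships in a built wheel is to import it by the dotted path registered above. A sketch, assuming the package has been installed:

```python
import importlib

# If setup.py packaged the new subdirectory correctly, this import succeeds
# and points into the installed torch_tensorrt/fx/tracer/dispatch_tracer tree.
mod = importlib.import_module("torch_tensorrt.fx.tracer.dispatch_tracer")
print(mod.__file__)
```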
11 changes: 11 additions & 0 deletions py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
@@ -165,6 +165,7 @@ def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             torch.ops.aten.max_pool3d_with_indices.default,
             torch.ops.aten.native_batch_norm.default,
             torch.ops.aten._native_batch_norm_legit.default,
+            torch.ops.aten._native_batch_norm_legit_no_training.default,
         ):
             modified = True
             if len(n.users) != 1:
@@ -185,6 +186,16 @@ def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 new_args = list(n.args)
                 new_args.append(False)
                 new_args = tuple(new_args)
+            elif (
+                n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
+            ):
+                new_op = torch.ops.aten.batch_norm
+                new_args = list(n.args)
+                new_args.append(False)
+                # _native_batch_norm_legit_no_training doesn't take a training arg
+                # (it is assumed to be False), but batch_norm takes a training arg
+                # at position 5.
+                new_args.insert(5, False)
+                new_args = tuple(new_args)
 
             getitem_node = next(iter(n.users))
             with module.graph.inserting_after(getitem_node):
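To make the argument remapping concrete: per the standard ATen schemas, `_native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps)` lacks both the `training` flag that `batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)` takes at position 5 and the trailing `cudnn_enabled` flag, so the pass appends one `False` and inserts the other. A standalone sketch of just that list manipulation, with placeholder argument values:

```python
# Placeholder args in _native_batch_norm_legit_no_training order:
# (input, weight, bias, running_mean, running_var, momentum, eps)
no_training_args = ("x", "w", "b", "mean", "var", 0.1, 1e-5)

new_args = list(no_training_args)
new_args.append(False)     # trailing cudnn_enabled=False for batch_norm
new_args.insert(5, False)  # training=False goes at position 5
print(tuple(new_args))
# -> ('x', 'w', 'b', 'mean', 'var', False, 0.1, 1e-5, False)
# which matches batch_norm's
# (input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)
```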