Description
Tested on: GeForce RTX 2070 GPU
Model: resnet18, traced with the following Python script:
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
model.eval()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
with torch.jit.optimized_execution(True):
    traced_script_module = torch.jit.trace(model, example)

output = traced_script_module(torch.ones(1, 3, 224, 224))
print(output[0, :5])
traced_script_module.save("traced_resnet_model.pt")
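For completeness, a quick way to sanity-check the saved trace (a minimal sketch, reusing the names from the script above; not part of the original repro):

import torch

reloaded = torch.jit.load("traced_resnet_model.pt")
fixed_input = torch.ones(1, 3, 224, 224)
# The reloaded module should reproduce the traced module's output exactly.
assert torch.allclose(reloaded(fixed_input), traced_script_module(fixed_input))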
bazel run //cpp/benchmark --cxxopt="-DNDEBUG" --cxxopt="-DJIT" --cxxopt="-DTRT" <path to .pt> "(1 3 224 224)"
[JIT/TRT]: batch_size: 1
Average latency: 1.53636 ms
Average FPS: 650.889 fps
Latency Standard Deviation: 0.0887908
FPS Standard Deviation: 39.0779
(excluding initial warmup runs)
[JIT]: batch_size: 1
Average latency: 2.62867 ms
Average FPS: 380.421 fps
Latency Standard Deviation: 0.0781293
FPS Standard Deviation: 14.7786
(excluding initial warmup runs)
ok
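For reference, this is how I read the numbers above: mean and standard deviation over timed iterations after discarding the warmup runs. A minimal Python sketch of that idea follows (my own illustration, not the actual cpp/benchmark code; the warmup and run counts are placeholders):

import time
import torch

def benchmark(module, input_tensor, warmup=10, runs=100):
    # warmup/runs are assumed values, not the benchmark's real settings.
    with torch.no_grad():
        for _ in range(warmup):
            module(input_tensor)  # discarded: CUDA init, autotuning, JIT fusion
        torch.cuda.synchronize()
        latencies_ms = []
        for _ in range(runs):
            start = time.time()
            module(input_tensor)
            torch.cuda.synchronize()  # wait for the GPU before stopping the clock
            latencies_ms.append((time.time() - start) * 1000.0)
    lat = torch.tensor(latencies_ms)
    fps = 1000.0 / lat  # at batch_size 1, FPS is just the inverse of latency
    print(f"Average latency: {lat.mean().item():.5f} ms (std {lat.std().item():.5f})")
    print(f"Average FPS: {fps.mean().item():.3f} fps (std {fps.std().item():.3f})")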
And a second run with --cxxopt="-DHALF" added:
[JIT/TRT]: batch_size: 1
Average latency: 0.921167 ms
Average FPS: 1085.58 fps
Latency Standard Deviation: 0.0322918
FPS Standard Deviation: 33.7609
(excluding initial warmup runs)
[JIT]: batch_size: 1
Average latency: 2.56868 ms
Average FPS: 389.305 fps
Latency Standard Deviation: 0.0382577
FPS Standard Deviation: 6.43474
(excluding initial warmup runs)
ok
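My understanding of the -DHALF path, judging only by the flag name (an assumption, not taken from the benchmark source), is that it amounts to running the model and input in FP16, roughly:

import torch
import torchvision

# Assumed equivalent of -DHALF: FP16 weights and FP16 input on the GPU.
model = torchvision.models.resnet18().eval().cuda().half()
example = torch.rand(1, 3, 224, 224, device="cuda").half()
with torch.no_grad():
    output = model(example)  # forward pass in half precision end to end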
Change L154 in cpp/ptq/main.cpp to match the dims above, then run:
bazel run //cpp/ptq -- <path to .pt> <path to cifar10 data>
[JIT model FP32]: batch_size: 1
Average latency: 2.64305 ms
Average FPS: 378.35 fps
Latency Standard Deviation: 0.0818512
FPS Standard Deviation: 14.3568
(excluding initial warmup runs)
[TRT quantized model]: batch_size: 1
Average latency: 1.42412 ms
Average FPS: 702.189 fps
Latency Standard Deviation: 0.0751042
FPS Standard Deviation: 31.2702
(excluding initial warmup runs)
I observe little appreciable speedup from int8 quantization compared to full-precision JIT/TRT; in fact, half precision is much faster. I understand that hardware plays a role and my GPU is probably not built for int8 inference, but I did not expect int8 to fare worse than half precision.