
🐛 [Bug] Encountered bug when using TRTorch #264

Closed
@Trouble404


Bug Description

The Python example code cannot be executed with a master-built trtorch because the right arguments cannot be passed to the Python trtorch.compile API:
trtorch.compile(script_model, compile_spec)
TypeError: compile_graph(): incompatible function arguments. The following argument types are supported:
1. (arg0: torch::jit::Module, arg1: trtorch._C.CompileSpec) -> torch::jit::Module

Invoked with: <torch._C.ScriptModule object at 0x7fcba18ae7f0>, <trtorch._C.CompileSpec object at 0x7fcbd1b8bcb0>
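
As a quick sanity check on what is actually being handed to the binding (a minimal sketch, assuming the standard PyTorch wrapping where torch.jit.script() returns a Python-level ScriptModule whose underlying C++ module is exposed via the ._c attribute):

import torch

# Minimal sketch, assuming standard PyTorch internals: torch.jit.script()
# returns a torch.jit.RecursiveScriptModule wrapper, and its underlying C++
# module (the torch._C.ScriptModule seen in the error above) lives in ._c.
print(type(script_model))     # expected: torch.jit.RecursiveScriptModule
print(type(script_model._c))  # expected: torch._C.ScriptModule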

To Reproduce

Steps to reproduce the behavior:

  1. Build dependencies from the tarball distributions.
  2. Build the Python wheel.
  3. Try to compile LeNet from the doc: https://nvidia.github.io/TRTorch/tutorials/getting_started.html

import torch.nn as nn
import torch.nn.functional as F
import torch.jit
import torch
import trtorch

class LeNetFeatExtractor(nn.Module):
    def __init__(self):
        super(LeNetFeatExtractor, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        return x

class LeNetClassifier(nn.Module):
    def __init__(self):
        super(LeNetClassifier, self).__init__()
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.feat = LeNetFeatExtractor()
        self.classifier = LeNetClassifier()

    def forward(self, x):
        x = self.feat(x)
        x = self.classifier(x)
        return x

model = LeNet()
model.eval()
script_model = torch.jit.script(model)
script_model.save("lenet_scripted.ts")
script_model.eval()

compile_spec = {
    "input_shapes": [
        (1, 3, 224, 224),  # Static input shape for input #1
        {
            "min": (1, 1, 16, 16),
            "opt": (1, 1, 32, 32),
            "max": (1, 1, 64, 64)
        }  # Dynamic input shape for input #2
    ],
    "device": {
        # "device_type": torch.device("cuda"),  # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
        "device_type": trtorch.DeviceType.GPU,
        "gpu_id": 0,  # Target gpu id to run engine (Use Xavier as gpu id for DLA)
        "dla_core": 0,  # (DLA only) Target dla core id to run engine
        "allow_gpu_fallback": False,  # (DLA only) Allow layers unsupported on DLA to run on GPU
    },
    "op_precision": torch.half,  # Operating precision set to FP16
    "refit": False,  # Enable refit
    "debug": False,  # Enable debuggable engine
    "strict_types": False,  # Kernels should strictly run in operating precision
    "capability": trtorch.EngineCapability.default,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
    "num_min_timing_iters": 2,  # Number of minimization timing iterations used to select kernels
    "num_avg_timing_iters": 1,  # Number of averaging timing iterations used to select kernels
    "workspace_size": 0,  # Maximum size of workspace given to TensorRT
    "max_batch_size": 0,  # Maximum batch size (must be >= 1 to be set, 0 means not set)
}

trt_ts_module = trtorch.compile(script_model, compile_spec)
print("done")

Expected behavior

trt_ts_module should be built correctly.
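
For context, this is roughly what success should look like (a sketch loosely following the same getting-started guide; the input shape matches the "opt" profile and FP16 matches op_precision in the spec above):

# Sketch of the expected follow-up once compilation succeeds. The shape
# (1, 1, 32, 32) matches the "opt" profile in compile_spec; .half() matches
# "op_precision": torch.half.
input_data = torch.randn((1, 1, 32, 32)).half().cuda()
result = trt_ts_module(input_data)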

Environment

Build information about the TRTorch compiler can be found by turning on debug messages
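
For example (assuming the trtorch.logging API from the TRTorch docs):

import trtorch

# Assumed from the TRTorch docs: raising the reportable log level makes the
# compiler print its build information and debug messages.
trtorch.logging.set_reportable_log_level(trtorch.logging.Level.Debug)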

  • PyTorch Version (e.g., 1.7.1):
  • CPU Architecture: Intel(R) Xeon(R) CPU E5-2640 v4 @ 2.40GHz
  • OS (e.g., Linux): Ubuntu 18.04.5
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): bazel build //:libtrtorch --compilation_mode opt --distdir third_party/dist_dir/[x86_64-linux-gnu | aarch64-linux-gnu] && cd py && python setup.py install
  • Are you using local sources or building from archives: building from archives (tarball)
  • Python version: 3.7.9
  • CUDA version: 11.0
  • GPU models and configuration: TITAN Xp, GTX 1080 Ti
  • Any other relevant information:

Additional context

The Python example from the 0.1.0 docs cannot run against a master-built trtorch.
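
One thing that may be worth checking (an assumption on my part, not a confirmed diagnosis): a pybind TypeError like the one above can occur when the wheel was built against a different libtorch than the torch installed from pip, since the torch::jit::Module type registrations then come from two different builds. A quick comparison:

import torch
import trtorch

# Assumption: if these versions disagree with the libtorch tarball used for
# the bazel build, the pybind registration for torch::jit::Module will not
# match, producing exactly the "incompatible function arguments" TypeError.
print(torch.__version__)
print(trtorch.__version__)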
