Making the PyTorch TAPIR model compatible with TorchScript *tracing* is easy. In `TAPIR.forward()` (https://github.com/google-deepmind/tapnet/blob/main/torch/tapir_model.py#L196-L209), replace

```python
out = dict(
    occlusion=torch.mean(
        torch.stack(trajectories['occlusion'][p::p]), dim=0
    ),
    tracks=torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
    expected_dist=torch.mean(
        torch.stack(trajectories['expected_dist'][p::p]), dim=0
    ),
    unrefined_occlusion=trajectories['occlusion'][:-1],
    unrefined_tracks=trajectories['tracks'][:-1],
    unrefined_expected_dist=trajectories['expected_dist'][:-1],
)
return out
```

with

```python
from typing import NamedTuple

class Output(NamedTuple):
    occlusion: torch.Tensor
    tracks: torch.Tensor
    expected_dist: torch.Tensor

out = Output(
    torch.mean(torch.stack(trajectories['occlusion'][p::p]), dim=0),
    torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
    torch.mean(torch.stack(trajectories['expected_dist'][p::p]), dim=0),
)
return out
```

(assuming it is OK to drop the `unrefined_*` entries from the output). With that change, the following script traces and saves the model successfully:

```python
import torch
from tapnet.torch import tapir_model

model = tapir_model.TAPIR(pyramid_level=1)
model.load_state_dict(torch.load('bootstapir_checkpoint.pt'))
model = model.to(torch.device('cpu'))
model.eval()

# Dummy inputs: frames are (batch, num_frames, height, width, 3),
# query points are (batch, num_points, 3).
dummy_input_frames = torch.randn(1, 32, 256, 256, 3, dtype=torch.float32, device=torch.device('cpu'))
dummy_input_query_points = torch.randn(1, 20, 3, dtype=torch.float32, device=torch.device('cpu'))

scriptModule = torch.jit.trace(model, (dummy_input_frames, dummy_input_query_points))
torch.jit.save(scriptModule, 'bootstapir_checkpoint.ptc')
```

Making the model compatible with TorchScript *scripting* is much harder. Calling

```python
scriptModule = torch.jit.script(model)
```

fails with:
```
Module 'BlockV2' has no attribute 'proj_conv' :
  File "C:\tapnet\tapnet\torch\nets.py", line 278
    x = torch.relu(x)
    if self.use_projection:
      shortcut = self.proj_conv(x)
                 ~~~~~~~~~~~~~~ <--- HERE
```
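The difference between the two modes explains the error. `torch.jit.trace` only records the operations executed for the example input, so a branch that is never taken is never looked at. `torch.jit.script` compiles the whole `forward()` source, including the body of `if self.use_projection:` even when that flag is `False`, so `self.proj_conv` must exist on every instance. Judging from the message, `BlockV2` presumably creates `proj_conv` only when `use_projection` is true. The minimal reproduction below uses a hypothetical `BrokenBlock` that mimics that pattern; it is not the actual `BlockV2` code from `nets.py`.

```python
import torch
from torch import nn


class BrokenBlock(nn.Module):
    """Hypothetical toy block mimicking the pattern in the traceback."""

    def __init__(self, channels: int, use_projection: bool):
        super().__init__()
        self.use_projection = use_projection
        if use_projection:
            # proj_conv is only created when use_projection is True.
            self.proj_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(x)
        shortcut = x
        if self.use_projection:
            # Scripting compiles this line even when the flag is False.
            shortcut = self.proj_conv(x)
        return self.conv(x) + shortcut


def try_script(module: nn.Module) -> bool:
    """Return True if torch.jit.script compiles the module, False otherwise."""
    try:
        torch.jit.script(module)
        return True
    except Exception:  # TorchScript compile errors surface as Python exceptions
        return False


# Tracing follows only the branch taken for this input, so proj_conv is never needed.
traced = torch.jit.trace(BrokenBlock(4, use_projection=False), torch.randn(1, 4, 8, 8))

print(try_script(BrokenBlock(4, use_projection=True)))   # True: proj_conv exists
print(try_script(BrokenBlock(4, use_projection=False)))  # False: no attribute 'proj_conv'
```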
How can the model be made compatible with TorchScript scripting?
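One possible workaround for this particular error, sketched on hypothetical toy modules (`FixedBlock`, `ToyHead`) rather than on the real `BlockV2` and `TAPIR` classes: always define `proj_conv`, falling back to `nn.Identity()` when no projection is used. `nn.Identity` has no parameters, so the `state_dict` keys do not change and an existing checkpoint should still load. The sketch also writes the output `NamedTuple` with `torch.Tensor` (the type) rather than `torch.tensor` (a factory function) and defines it at module level, which is the safe place for a type TorchScript has to resolve. This only addresses the error shown above; other parts of TAPIR may raise further scripting errors that need their own fixes.

```python
from typing import Dict, List, NamedTuple

import torch
from torch import nn


class FixedBlock(nn.Module):
    """Hypothetical toy block: proj_conv is defined on every instance."""

    def __init__(self, channels_in: int, channels_out: int, use_projection: bool):
        super().__init__()
        self.use_projection = use_projection
        if use_projection:
            self.proj_conv = nn.Conv2d(channels_in, channels_out, kernel_size=1)
        else:
            # Parameter-free placeholder, so both branches of forward() compile
            # and the state_dict keys stay the same as before.
            self.proj_conv = nn.Identity()
        self.conv = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(x)
        shortcut = x
        if self.use_projection:
            shortcut = self.proj_conv(x)
        return self.conv(x) + shortcut


class Output(NamedTuple):
    # Annotate with the torch.Tensor type; scripting needs real type annotations.
    occlusion: torch.Tensor
    tracks: torch.Tensor
    expected_dist: torch.Tensor


class ToyHead(nn.Module):
    """Hypothetical head averaging every p-th refinement step, as in the issue."""

    def __init__(self, p: int):
        super().__init__()
        self.p = p

    def forward(self, trajectories: Dict[str, List[torch.Tensor]]) -> Output:
        p = self.p
        return Output(
            torch.mean(torch.stack(trajectories['occlusion'][p::p]), dim=0),
            torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
            torch.mean(torch.stack(trajectories['expected_dist'][p::p]), dim=0),
        )


block = torch.jit.script(FixedBlock(4, 4, use_projection=False))
head = torch.jit.script(ToyHead(p=2))
```

A common alternative, used for example by torchvision's ResNet blocks with their `downsample` attribute, is to assign `None` to the optional submodule and guard the call with `if self.proj_conv is not None:`; TorchScript refines that check statically. The `nn.Identity` version was chosen here because it leaves `self.use_projection` and the rest of `forward()` untouched.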