Torchscript compatibility #83

@ssandler-cat

While making the torch TAPIR model compatible with TorchScript tracing is easy, by changing the return value of TAPIR.forward() in https://github.com/google-deepmind/tapnet/blob/main/torch/tapir_model.py#L196-L209 from

    out = dict(
        occlusion=torch.mean(
            torch.stack(trajectories['occlusion'][p::p]), dim=0
        ),
        tracks=torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
        expected_dist=torch.mean(
            torch.stack(trajectories['expected_dist'][p::p]), dim=0
        ),
        unrefined_occlusion=trajectories['occlusion'][:-1],
        unrefined_tracks=trajectories['tracks'][:-1],
        unrefined_expected_dist=trajectories['expected_dist'][:-1],
    )

    return out

to

    from typing import NamedTuple

    class Output(NamedTuple):
        # Annotate with the torch.Tensor type; torch.tensor is a factory function.
        occlusion: torch.Tensor
        tracks: torch.Tensor
        expected_dist: torch.Tensor

    out = Output(
        occlusion=torch.mean(torch.stack(trajectories['occlusion'][p::p]), dim=0),
        tracks=torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
        expected_dist=torch.mean(
            torch.stack(trajectories['expected_dist'][p::p]), dim=0
        ),
    )

    return out

(assuming it is OK to drop the unrefined_* entries from the output), so that

    import torch
    from tapnet.torch import tapir_model  # import path assumed from the repository layout

    device = torch.device('cpu')
    model = tapir_model.TAPIR(pyramid_level=1)
    model.load_state_dict(torch.load('bootstapir_checkpoint.pt'))
    model = model.to(device)
    model.eval()
    # Dummy inputs used only to drive the tracer.
    dummy_input_frames = torch.randn(1, 32, 256, 256, 3, dtype=torch.float32, device=device)
    dummy_input_query_points = torch.randn(1, 20, 3, dtype=torch.float32, device=device)
    scriptModule = torch.jit.trace(model, (dummy_input_frames, dummy_input_query_points))
    torch.jit.save(scriptModule, 'bootstapir_checkpoint.ptc')

succeeds, it is not so easy to make the model compatible with TorchScript scripting.

    scriptModule = torch.jit.script(model)

fails with

    Module 'BlockV2' has no attribute 'proj_conv' :
      File "C:\tapnet\tapnet\torch\nets.py", line 278
        x = torch.relu(x)
        if self.use_projection:
          shortcut = self.proj_conv(x)
                     ~~~~~~~~~~~~~~ <--- HERE

How can the model be made compatible with TorchScript scripting?
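
For reference, the error above looks like the usual case where torch.jit.script compiles the body of the `if` even though the attribute it uses is only created when the flag is set. A minimal sketch of one common workaround follows, on a toy module rather than the real BlockV2: `ToyBlock` and its layers are hypothetical, and whether the same change is enough for the whole TAPIR model is an assumption that has not been checked here. The idea is to always define `proj_conv`, using `nn.Identity()` as a placeholder when no projection is wanted.

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a residual block; not tapnet's BlockV2."""

    def __init__(self, channels: int, use_projection: bool):
        super().__init__()
        self.use_projection = use_projection
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        # Always define proj_conv so the scripted forward() can resolve it.
        # nn.Identity() has no parameters and is never called when
        # use_projection is False.
        if use_projection:
            self.proj_conv = nn.Conv2d(channels, channels, kernel_size=1)
        else:
            self.proj_conv = nn.Identity()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        shortcut = inputs
        x = torch.relu(inputs)
        if self.use_projection:
            shortcut = self.proj_conv(x)
        return self.conv(x) + shortcut


# Script the variant without a projection, which is the case that failed above.
block = ToyBlock(channels=4, use_projection=False).eval()
scripted = torch.jit.script(block)

x = torch.randn(1, 4, 8, 8)
same = torch.allclose(block(x), scripted(x))
```

Because `nn.Identity()` holds no parameters, this placeholder should not add keys to the state dict, so an existing checkpoint would presumably still load. Other parts of the model may raise further scripting errors once this one is resolved.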
