
Commit 460eba7

Work around casting issue with combination of native torch AMP and torchscript for Linear layers
1 parent: 5f4b607

3 files changed: +21 −2 lines changed

timm/models/layers/classifier.py

Lines changed: 3 additions & 1 deletion
@@ -6,6 +6,7 @@
 from torch.nn import functional as F
 
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .linear import Linear
 
 
 def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
@@ -21,7 +22,8 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False)
     elif use_conv:
         fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
     else:
-        fc = nn.Linear(num_pooled_features, num_classes, bias=True)
+        # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
+        fc = Linear(num_pooled_features, num_classes, bias=True)
     return global_pool, fc
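For context, a minimal sketch of how the updated helper behaves on the non-conv path (the feature and class counts below are illustrative, not from the commit):

    from timm.models.layers.classifier import create_classifier
    from timm.models.layers.linear import Linear

    # With use_conv left at its default, the classifier head is now the AMP/torchscript-safe wrapper.
    global_pool, fc = create_classifier(num_features=2048, num_classes=1000, pool_type='avg')
    assert isinstance(fc, Linear)  # still a subclass of nn.Linear, so parameter names and checkpoints are unchanged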
timm/models/layers/linear.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+""" Linear layer (alternate definition)
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn as nn
+
+
+class Linear(nn.Linear):
+    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
+
+    Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
+    weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
+    """
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if torch.jit.is_scripting():
+            return F.linear(input, self.weight.to(dtype=input.dtype), self.bias.to(dtype=input.dtype))
+        else:
+            return F.linear(input, self.weight, self.bias)
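A sketch of the case the cast guards against, assuming a CUDA device: under native AMP, a scripted module can receive a float16 activation while its own weight and bias are still float32, and casting them to input.dtype keeps the underlying addmm call's dtypes consistent. Sizes here are illustrative:

    import torch
    from timm.models.layers.linear import Linear

    fc = torch.jit.script(Linear(2048, 1000).cuda())
    x = torch.randn(8, 2048, device='cuda', dtype=torch.float16)  # half activation, e.g. produced upstream under autocast
    out = fc(x)  # weight & bias are cast to float16 inside forward, so F.linear sees matching dtypes
    print(out.dtype)  # torch.float16

In eager (non-scripted) execution the else branch runs unchanged, so autocast's normal handling of F.linear still applies.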

train.py

Lines changed: 0 additions & 1 deletion
@@ -367,7 +367,6 @@ def main():
     if args.torchscript:
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
-        # FIXME I ran into a bug w/ AMP + torchscript + Linear layers
         model = torch.jit.script(model)
 
     optimizer = create_optimizer(args, model)
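With the wrapper in place the removed FIXME no longer applies, and the intent is that native torch AMP can be combined with a scripted model. A hypothetical invocation using this train.py's flags (dataset path and model name are placeholders):

    python train.py /path/to/imagenet --model resnet50 --amp --torchscript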
