Skip to content

Commit fabc4e5

Browse files
seefunrwightman
authored andcommitted
Fixing tinyvit torchscript issue
1 parent bae949f commit fabc4e5

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

timm/models/tiny_vit.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
__all__ = ['TinyVit']
1010
import math
1111
import itertools
12+
from typing import Dict
1213

1314
import torch
1415
import torch.nn as nn
1516
import torch.nn.functional as F
1617

1718
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18-
from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table
19+
from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table, _assert
1920
from ._builder import build_model_with_cfg
2021
from ._manipulate import checkpoint_seq
2122
from ._registry import register_model, generate_default_cfgs
@@ -178,6 +179,8 @@ def forward(self, x):
178179

179180

180181
class Attention(torch.nn.Module):
182+
attention_bias_cache: Dict[str, torch.Tensor]
183+
181184
def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4, resolution=(14, 14)):
182185
super().__init__()
183186
assert isinstance(resolution, tuple) and len(resolution) == 2
@@ -304,7 +307,7 @@ def __init__(
304307
def forward(self, x):
305308
H, W = self.input_resolution
306309
B, L, C = x.shape
307-
assert L == H * W, "input feature has wrong size"
310+
_assert(L == H * W, f"input feature has wrong size, expect {H * W}, got {L}")
308311
res_x = x
309312
if H == self.window_size and W == self.window_size:
310313
x = self.attn(x)

0 commit comments

Comments
 (0)