@@ -9,13 +9,14 @@
 __all__ = ['TinyVit']
 import math
 import itertools
+from typing import Dict

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table
+from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table, _assert
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -178,6 +179,8 @@ def forward(self, x):


 class Attention(torch.nn.Module):
+    attention_bias_cache: Dict[str, torch.Tensor]
+
     def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4, resolution=(14, 14)):
         super().__init__()
         assert isinstance(resolution, tuple) and len(resolution) == 2
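
The class-level annotation `attention_bias_cache: Dict[str, torch.Tensor]` gives TorchScript the key/value types of the mutable dict the module fills at runtime; without an explicit annotation, scripting a module that caches tensors this way is fragile. A minimal sketch of the pattern, assuming an illustrative module (names like `CachedBias` and `bias_cache` are placeholders, not the actual TinyVit code):

from typing import Dict

import torch
import torch.nn as nn


class CachedBias(nn.Module):
    # Class-level annotation so TorchScript knows the key/value types of the
    # mutable dict assigned in __init__ (mirrors attention_bias_cache above).
    bias_cache: Dict[str, torch.Tensor]

    def __init__(self, num_heads: int = 4, num_positions: int = 49):
        super().__init__()
        self.attention_biases = nn.Parameter(torch.zeros(num_heads, num_positions))
        self.bias_cache = {}

    def get_bias(self, device: torch.device) -> torch.Tensor:
        # Bypass the cache while tracing or training; reuse the cached tensor in eval.
        if torch.jit.is_tracing() or self.training:
            return self.attention_biases
        key = str(device)
        if key not in self.bias_cache:
            self.bias_cache[key] = self.attention_biases
        return self.bias_cache[key]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.get_bias(x.device)

m = torch.jit.script(CachedBias()).eval()
print(m(torch.zeros(1, 4, 49)).shape)  # torch.Size([1, 4, 49])
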
@@ -304,7 +307,7 @@ def __init__(
     def forward(self, x):
         H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+        _assert(L == H * W, f"input feature has wrong size, expect {H * W}, got {L}")
         res_x = x
         if H == self.window_size and W == self.window_size:
             x = self.attn(x)
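
Replacing the plain `assert` with `timm.layers._assert` matters for graph capture: `_assert` wraps `torch._assert`, which is symbolically traceable, so the shape check survives `torch.fx` tracing and `torch.jit.script` instead of trying to evaluate a Proxy as a bool. A small usage sketch (the `check_tokens` helper is hypothetical; only the `_assert` call mirrors the diff):

import torch
from timm.layers import _assert


def check_tokens(x: torch.Tensor, H: int, W: int) -> torch.Tensor:
    # _assert wraps torch._assert, so this check stays representable in
    # traced/scripted graphs where a bare Python `assert` would not.
    B, L, C = x.shape
    _assert(L == H * W, f"input feature has wrong size, expect {H * W}, got {L}")
    return x


check_tokens(torch.randn(2, 196, 64), 14, 14)    # passes: 14 * 14 == 196
# check_tokens(torch.randn(2, 100, 64), 14, 14)  # raises AssertionError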