
Commit 884ef88

YassineYousfi authored and rwightman committed
fix all SDPA dropouts
1 parent b500cae · commit 884ef88
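Context for the change (not part of the original commit message): torch.nn.functional.scaled_dot_product_attention takes dropout_p as a plain float and applies attention dropout whenever it is non-zero, regardless of the module's train/eval mode, whereas an nn.Dropout module becomes a no-op after model.eval(). Passing self.attn_drop.p unconditionally therefore left dropout active at inference time. Below is a minimal sketch of the pattern the commit applies across the attention modules; the FusedAttention class and its parameters are illustrative, not timm code.

import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedAttention(nn.Module):
    """Illustrative attention block using the dropout gating pattern from this commit."""

    def __init__(self, dim, num_heads=8, attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)  # holds the dropout rate; a non-fused fallback path would use it directly
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # SDPA applies dropout whenever dropout_p > 0, so the rate must be
        # zeroed explicitly when the module is not in training mode.
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)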

14 files changed: +22 −22 lines changed


timm/models/beit.py

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=rel_pos_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

timm/models/cait.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def forward(self, x):
         if self.fused_attn:
             x_cls = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

timm/models/eva.py

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ def forward(
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

timm/models/fastvit.py

Lines changed: 1 addition & 1 deletion
@@ -514,7 +514,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p if self.training else 0.0,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

timm/models/maxxvit.py

Lines changed: 2 additions & 2 deletions
@@ -190,7 +190,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
                 k.transpose(-1, -2).contiguous(),
                 v.transpose(-1, -2).contiguous(),
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
             q = q * self.scale
@@ -259,7 +259,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

timm/models/metaformer.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def forward(self, x):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale

timm/models/nest.py

Lines changed: 3 additions & 3 deletions
@@ -59,14 +59,14 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.)
     def forward(self, x):
         """
         x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
-        """
+        """
         B, T, N, C = x.shape
         # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
         qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)  # (B, H, T, N, N)
@@ -330,7 +330,7 @@ def __init__(
         # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
         # number of blocks along edge of image
         self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))
-
+
         # Patch embedding
         self.patch_embed = PatchEmbed(
             img_size=img_size,

timm/models/pvt_v2.py

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ def forward(self, x, feat_size: List[int]):
         k, v = kv.unbind(0)

         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)

timm/models/swin_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

timm/models/twins.py

Lines changed: 2 additions & 2 deletions
@@ -75,7 +75,7 @@ def forward(self, x, size: Size_):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -172,7 +172,7 @@ def forward(self, x, size: Size_):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
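A quick way to see the effect of the change (hypothetical check, not part of the commit), using the FusedAttention sketch above: with dropout_p gated on self.training, two eval-mode forward passes are deterministic even when attn_drop > 0.

import torch

torch.manual_seed(0)
attn = FusedAttention(dim=64, num_heads=8, attn_drop=0.1).eval()  # eval mode: dropout_p resolves to 0.
x = torch.randn(2, 16, 64)

with torch.no_grad():
    y1 = attn(x)
    y2 = attn(x)

# Before this fix, attention dropout stayed active in eval mode and the two
# outputs would differ; with the gated dropout_p they match.
assert torch.allclose(y1, y2)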
