Skip to content

Commit 72b227d

Browse files
authored
Merge pull request #750 from drjinying/master
Specify "interpolation" mode in vision_transformer's resize_pos_embed
2 parents 2907c1f + 20b2d4b commit 72b227d

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

timm/models/nest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def resize_pos_embed(posemb, posemb_new):
377377
size_new = int(math.sqrt(num_blocks_new*seq_length_new))
378378
# First change to (1, C, H, W)
379379
posemb = deblockify(posemb, int(math.sqrt(seq_length_old))).permute(0, 3, 1, 2)
380-
posemb = F.interpolate(posemb, size=[size_new, size_new], mode='bilinear')
380+
posemb = F.interpolate(posemb, size=[size_new, size_new], mode='bicubic', align_corners=False)
381381
# Now change to new (1, T, N, C)
382382
posemb = blockify(posemb.permute(0, 2, 3, 1), int(math.sqrt(seq_length_new)))
383383
return posemb

timm/models/vision_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
494494
assert len(gs_new) >= 2
495495
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
496496
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
497-
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
497+
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
498498
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
499499
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
500500
return posemb

0 commit comments

Comments
 (0)