Commit 29f94f4

continue exploration with antigravity, allowing for a manual EMA update with input and indices outside of forward
1 parent 8f52cd8 commit 29f94f4
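
In short, this change lets the EMA codebook update be applied after the fact, from a saved input and its returned indices, rather than only inside forward. A minimal sketch of the intended flow, mirroring the new test below (the ema_update keyword and update_ema_indices method are the ones introduced in this diff):

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)
vq.train()

x = torch.randn(1, 1024, 256)

# run quantization but skip the usual in-forward EMA codebook update
quantize, indices, commit_loss = vq(x, ema_update = False)

# ... later, apply the EMA update manually from the saved input and indices
vq.update_ema_indices(x, indices)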

File tree: 2 files changed (+136, -31 lines)

tests/test_manual_ema.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import torch
+from vector_quantize_pytorch import VectorQuantize
+
+def test_manual_ema_update():
+
+    vq1 = VectorQuantize(
+        dim = 256,
+        codebook_size = 512
+    )
+
+    vq2 = VectorQuantize(
+        dim = 256,
+        codebook_size = 512
+    )
+
+    vq2.load_state_dict(vq1.state_dict())
+
+    x = torch.randn(1, 1024, 256)
+    mask = torch.randint(0, 2, (1, 1024)).bool()
+
+    vq1.train()
+    quantize1, indices1, _ = vq1(x, mask = mask)
+
+    vq2.train()
+    quantize2, indices2, _ = vq2(x, mask = mask, ema_update = False)
+
+    assert torch.allclose(quantize1, quantize2)
+    assert torch.equal(indices1, indices2)
+
+    assert not torch.allclose(vq1._codebook.embed_avg, vq2._codebook.embed_avg)
+
+    vq2.update_ema_indices(x, indices2, mask = mask)
+
+    assert torch.allclose(vq1._codebook.cluster_size, vq2._codebook.cluster_size)
+    assert torch.allclose(vq1._codebook.embed_avg, vq2._codebook.embed_avg)
+    assert torch.allclose(vq1.codebook, vq2.codebook)
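
The test drives both quantizers with a boolean mask, and the manual path honors it the same way forward does: code assignments at masked-out positions are zeroed before anything is counted. A tiny standalone illustration of that masking step, on hypothetical shapes unrelated to the real buffer layout:

import torch
import torch.nn.functional as F

n, c = 6, 4                                   # tokens, codebook size
embed_ind = torch.randint(0, c, (n,))
embed_onehot = F.one_hot(embed_ind, c).float()

mask = torch.tensor([True, True, False, True, False, True])
embed_onehot[~mask] = 0.                      # masked tokens drop out entirely

cluster_size = embed_onehot.sum(dim = 0)      # only unmasked tokens are counted
assert cluster_size.sum() == mask.sum()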

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 100 additions & 31 deletions
@@ -556,6 +556,69 @@ def update_ema(self):
 
         self.embed.data.copy_(embed_normalized)
 
+    def update_ema_part(
+        self,
+        flatten,
+        embed_onehot,
+        mask = None,
+        ema_update_weight: Tensor | Callable | None = None,
+        accum_ema_update = False
+    ):
+        if self.affine_param:
+            codebook_std = self.codebook_variance.clamp(min = 1e-5).sqrt()
+            batch_std = self.batch_variance.clamp(min = 1e-5).sqrt()
+            flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
+
+        if exists(mask):
+            embed_onehot[~mask] = 0.
+
+        cluster_size = embed_onehot.sum(dim = 1)
+        self.all_reduce_fn(cluster_size)
+
+        embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
+        embed_sum = embed_sum.contiguous()
+        self.all_reduce_fn(embed_sum)
+
+        if callable(ema_update_weight):
+            ema_update_weight = ema_update_weight(embed_sum, cluster_size)
+
+        if accum_ema_update:
+            accum_grad_(self.cluster_size, cluster_size)
+            accum_grad_(self.embed_avg, embed_sum)
+        else:
+            ema_inplace(self.cluster_size, cluster_size, self.decay, ema_update_weight)
+            ema_inplace(self.embed_avg, embed_sum, self.decay, ema_update_weight)
+
+        if not self.manual_ema_update:
+            self.update_ema()
+            self.expire_codes_(flatten)
+
+    def update_ema_indices(
+        self,
+        x,
+        embed_ind,
+        mask = None,
+        ema_update_weight: Tensor | Callable | None = None,
+        accum_ema_update = False
+    ):
+        needs_codebook_dim = x.ndim < 4
+        x = x.float()
+
+        if needs_codebook_dim:
+            x = rearrange(x, '... -> 1 ...')
+
+        dtype = x.dtype
+        flatten, unpack_one = pack_one(x, 'h * d')
+
+        if exists(mask):
+            mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))
+
+        embed_ind, _ = pack([embed_ind], 'h *')
+        embed_ind = embed_ind.masked_fill(embed_ind == -1, 0)
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+
+        self.update_ema_part(flatten, embed_onehot, mask = mask, ema_update_weight = ema_update_weight, accum_ema_update = accum_ema_update)
+
     @autocast('cuda', enabled = False)
     def forward(
         self,
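
For orientation, the statistics that update_ema_part maintains are per-code hit counts and per-code input sums, blended into the codebook's running buffers. A standalone sketch of that bookkeeping on hypothetical shapes, with a single codebook head and ignoring the affine re-parameterization, distributed all-reduce, masking, and the optional ema_update_weight:

import torch
import torch.nn.functional as F

h, n, d, c = 1, 1024, 256, 512        # heads, tokens, dim, codebook size
decay = 0.8                           # example decay value

flatten = torch.randn(h, n, d)                   # flattened inputs per head
embed_ind = torch.randint(0, c, (h, n))          # chosen code per token
embed_onehot = F.one_hot(embed_ind, c).float()   # (h, n, c)

cluster_size = embed_onehot.sum(dim = 1)         # (h, c): how often each code was hit
embed_sum = torch.einsum('h n d, h n c -> h c d', flatten, embed_onehot)

# running buffers, as kept on the codebook module
running_cluster_size = torch.zeros(h, c)
running_embed_avg = torch.randn(h, c, d)

# exponential moving average blend, roughly what ema_inplace does
running_cluster_size.mul_(decay).add_(cluster_size, alpha = 1 - decay)
running_embed_avg.mul_(decay).add_(embed_sum, alpha = 1 - decay)
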
@@ -565,8 +628,11 @@ def forward(
         freeze_codebook = False,
         codebook_transform_fn: Callable | None = None,
         ema_update_weight: Tensor | Callable | None = None,
-        accum_ema_update = False
+        accum_ema_update = False,
+        ema_update = None
     ):
+        ema_update = default(ema_update, self.ema_update)
+
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
 
@@ -650,34 +716,8 @@ def forward(
         repeated_embed_ind = repeat(embed_ind, 'h b n -> h b n d', d = embed.shape[-1])
         quantize = repeated_embed.gather(-2, repeated_embed_ind)
 
-        if self.training and self.ema_update and not freeze_codebook:
-
-            if self.affine_param:
-                flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
-
-            if exists(mask):
-                embed_onehot[~mask] = 0.
-
-            cluster_size = embed_onehot.sum(dim = 1)
-            self.all_reduce_fn(cluster_size)
-
-            embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
-            embed_sum = embed_sum.contiguous()
-            self.all_reduce_fn(embed_sum)
-
-            if callable(ema_update_weight):
-                ema_update_weight = ema_update_weight(embed_sum, cluster_size)
-
-            if accum_ema_update:
-                accum_grad_(self.cluster_size, cluster_size)
-                accum_grad_(self.embed_avg, embed_sum)
-            else:
-                ema_inplace(self.cluster_size, cluster_size, self.decay, ema_update_weight)
-                ema_inplace(self.embed_avg, embed_sum, self.decay, ema_update_weight)
-
-            if not self.manual_ema_update:
-                self.update_ema()
-                self.expire_codes_(x)
+        if self.training and ema_update and not freeze_codebook:
+            self.update_ema_part(flatten, embed_onehot, mask = mask, ema_update_weight = ema_update_weight, accum_ema_update = accum_ema_update)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -934,6 +974,33 @@ def expire_codes_(self, x):
         x = self.maybe_split_heads_from_input(x)
         self._codebook.expire_codes_(x)
 
+    def update_ema_indices(self, x, indices, mask = None):
+        if self.accept_image_fmap:
+            assert not exists(mask)
+            height, width = x.shape[-2:]
+            x = rearrange(x, 'b c h w -> b (h w) c')
+
+        if not self.channel_last and not self.accept_image_fmap:
+            x = rearrange(x, 'b d n -> b n d')
+
+        x = self.project_in(x)
+        x = self.maybe_split_heads_from_input(x)
+        x = self._codebook.transform_input(x)
+
+        if self.heads > 1:
+            if self.separate_codebook_per_head:
+                indices = rearrange(indices, 'b n h -> h b n')
+            else:
+                indices = rearrange(indices, 'b n h -> 1 (b h) n')
+
+        if self.accept_image_fmap:
+            indices = rearrange(indices, 'b h w ... -> b (h w) ...')
+
+        if x.ndim == 2: # only one token
+            indices = rearrange(indices, 'b ... -> b 1 ...')
+
+        self._codebook.update_ema_indices(x, indices, mask = mask)
+
     def forward(
         self,
         x,
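
At the wrapper level, update_ema_indices repeats the same input massaging as forward (image feature maps, channel ordering, input projection, head splitting) before handing off to the codebook. A hedged usage sketch with an image feature map, assuming the library's existing accept_image_fmap constructor flag:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512, accept_image_fmap = True)
vq.train()

fmap = torch.randn(1, 256, 32, 32)     # (batch, dim, height, width)

# quantize without touching the EMA buffers
quantized, indices, _ = vq(fmap, ema_update = False)   # indices: (batch, height, width)

# fold the statistics in later, from the stored feature map and indices
vq.update_ema_indices(fmap, indices)
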
@@ -945,7 +1012,8 @@ def forward(
         return_loss_breakdown = False,
         codebook_transform_fn: Callable | None = None,
         ema_update_weight: Tensor | None = None,
-        accum_ema_update = False
+        accum_ema_update = False,
+        ema_update = None
     ):
         orig_input, input_requires_grad = x, x.requires_grad
 
@@ -1003,7 +1071,8 @@ def forward(
             freeze_codebook = freeze_codebook,
             codebook_transform_fn = codebook_transform_fn,
             ema_update_weight = ema_update_weight,
-            accum_ema_update = accum_ema_update
+            accum_ema_update = accum_ema_update,
+            ema_update = ema_update
         )
 
         # quantize

0 commit comments
