@@ -556,6 +556,69 @@ def update_ema(self):
 
         self.embed.data.copy_(embed_normalized)
 
+    def update_ema_part(
+        self,
+        flatten,
+        embed_onehot,
+        mask = None,
+        ema_update_weight: Tensor | Callable | None = None,
+        accum_ema_update = False
+    ):
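+        # EMA statistics update factored out of forward: accumulates cluster sizes and
+        # per-code embedding sums from pre-flattened inputs and one-hot code assignments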
+        if self.affine_param:
+            codebook_std = self.codebook_variance.clamp(min = 1e-5).sqrt()
+            batch_std = self.batch_variance.clamp(min = 1e-5).sqrt()
+            flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
+
+        if exists(mask):
+            embed_onehot[~mask] = 0.
+
+        cluster_size = embed_onehot.sum(dim = 1)
+        self.all_reduce_fn(cluster_size)
+
+        embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
+        embed_sum = embed_sum.contiguous()
+        self.all_reduce_fn(embed_sum)
+
+        if callable(ema_update_weight):
+            ema_update_weight = ema_update_weight(embed_sum, cluster_size)
+
+        if accum_ema_update:
+            accum_grad_(self.cluster_size, cluster_size)
+            accum_grad_(self.embed_avg, embed_sum)
+        else:
+            ema_inplace(self.cluster_size, cluster_size, self.decay, ema_update_weight)
+            ema_inplace(self.embed_avg, embed_sum, self.decay, ema_update_weight)
+
+        if not self.manual_ema_update:
+            self.update_ema()
+            self.expire_codes_(flatten)
+
+    def update_ema_indices(
+        self,
+        x,
+        embed_ind,
+        mask = None,
+        ema_update_weight: Tensor | Callable | None = None,
+        accum_ema_update = False
+    ):
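+        # build one-hot code assignments from precomputed indices, then defer to update_ema_part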
+        needs_codebook_dim = x.ndim < 4
+        x = x.float()
+
+        if needs_codebook_dim:
+            x = rearrange(x, '... -> 1 ...')
+
+        dtype = x.dtype
+        flatten, unpack_one = pack_one(x, 'h * d')
+
+        if exists(mask):
+            mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))
+
+        embed_ind, _ = pack([embed_ind], 'h *')
+        embed_ind = embed_ind.masked_fill(embed_ind == -1, 0)
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+
+        self.update_ema_part(flatten, embed_onehot, mask = mask, ema_update_weight = ema_update_weight, accum_ema_update = accum_ema_update)
+
     @autocast('cuda', enabled = False)
     def forward(
         self,
@@ -565,8 +628,11 @@ def forward(
         freeze_codebook = False,
         codebook_transform_fn: Callable | None = None,
         ema_update_weight: Tensor | Callable | None = None,
-        accum_ema_update = False
+        accum_ema_update = False,
+        ema_update = None
     ):
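+        # per-call override of the module-level ema_update setting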
+        ema_update = default(ema_update, self.ema_update)
+
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
 
@@ -650,34 +716,8 @@ def forward(
             repeated_embed_ind = repeat(embed_ind, 'h b n -> h b n d', d = embed.shape[-1])
             quantize = repeated_embed.gather(-2, repeated_embed_ind)
 
-        if self.training and self.ema_update and not freeze_codebook:
-
-            if self.affine_param:
-                flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
-
-            if exists(mask):
-                embed_onehot[~mask] = 0.
-
-            cluster_size = embed_onehot.sum(dim = 1)
-            self.all_reduce_fn(cluster_size)
-
-            embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
-            embed_sum = embed_sum.contiguous()
-            self.all_reduce_fn(embed_sum)
-
-            if callable(ema_update_weight):
-                ema_update_weight = ema_update_weight(embed_sum, cluster_size)
-
-            if accum_ema_update:
-                accum_grad_(self.cluster_size, cluster_size)
-                accum_grad_(self.embed_avg, embed_sum)
-            else:
-                ema_inplace(self.cluster_size, cluster_size, self.decay, ema_update_weight)
-                ema_inplace(self.embed_avg, embed_sum, self.decay, ema_update_weight)
-
-            if not self.manual_ema_update:
-                self.update_ema()
-                self.expire_codes_(x)
+        if self.training and ema_update and not freeze_codebook:
+            self.update_ema_part(flatten, embed_onehot, mask = mask, ema_update_weight = ema_update_weight, accum_ema_update = accum_ema_update)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -934,6 +974,33 @@ def expire_codes_(self, x):
         x = self.maybe_split_heads_from_input(x)
         self._codebook.expire_codes_(x)
 
+    def update_ema_indices(self, x, indices, mask = None):
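+        # public entry point: mirrors the input and index reshaping done in forward,
+        # then updates the codebook EMA statistics from the given indices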
+        if self.accept_image_fmap:
+            assert not exists(mask)
+            height, width = x.shape[-2:]
+            x = rearrange(x, 'b c h w -> b (h w) c')
+
+        if not self.channel_last and not self.accept_image_fmap:
+            x = rearrange(x, 'b d n -> b n d')
+
+        x = self.project_in(x)
+        x = self.maybe_split_heads_from_input(x)
+        x = self._codebook.transform_input(x)
+
+        if self.heads > 1:
+            if self.separate_codebook_per_head:
+                indices = rearrange(indices, 'b n h -> h b n')
+            else:
+                indices = rearrange(indices, 'b n h -> 1 (b h) n')
+
+        if self.accept_image_fmap:
+            indices = rearrange(indices, 'b h w ... -> b (h w) ...')
+
+        if x.ndim == 2: # only one token
+            indices = rearrange(indices, 'b ... -> b 1 ...')
+
+        self._codebook.update_ema_indices(x, indices, mask = mask)
+
     def forward(
         self,
         x,
@@ -945,7 +1012,8 @@ def forward(
         return_loss_breakdown = False,
         codebook_transform_fn: Callable | None = None,
         ema_update_weight: Tensor | None = None,
-        accum_ema_update = False
+        accum_ema_update = False,
+        ema_update = None
     ):
         orig_input, input_requires_grad = x, x.requires_grad
 
@@ -1003,7 +1071,8 @@ def forward(
             freeze_codebook = freeze_codebook,
             codebook_transform_fn = codebook_transform_fn,
             ema_update_weight = ema_update_weight,
-            accum_ema_update = accum_ema_update
+            accum_ema_update = accum_ema_update,
+            ema_update = ema_update
         )
 
         # quantize
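
Below is a minimal usage sketch (not part of the commit) of the new entry points: the `ema_update` keyword added to `forward` and the `update_ema_indices` method on `VectorQuantize`. The constructor arguments and the three-value unpacking of the forward return follow the library's usual `VectorQuantize` API and are assumptions here, as is the two-step workflow of quantizing first and applying the EMA codebook update afterwards.

    import torch
    from vector_quantize_pytorch import VectorQuantize

    # assumed constructor arguments, not taken from this diff
    vq = VectorQuantize(dim = 256, codebook_size = 512, decay = 0.8)
    vq.train()

    x = torch.randn(1, 1024, 256)

    # new kwarg from this commit: skip the EMA codebook update inside forward
    quantized, indices, commit_loss = vq(x, ema_update = False)

    # new method from this commit: apply the EMA update later, from the precomputed indices
    vq.update_ema_indices(x, indices)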