Conversation


@tongdaxu commented Dec 7, 2025

GaussianQuant: VQ-VAE using Gaussian VAE

  • Conceptually:
    • Train a Gaussian VAE and convert it into a VQ-VAE without much loss!
    • The Gaussian VAE needs to satisfy $D_{KL}(q(Z^i|X)\,\|\,N(0,1)) \approx \log_2(\text{codebook size})$. This is achieved by a specially designed KL-divergence loss; a small sanity check on the nats-to-bits conversion is sketched in the notes at the end of this comment. See details in gaussian_quant.py:
      # element-wise KL divergence to N(0, 1), in bits (1.4426 ≈ log2(e) converts nats to bits)
      kl2 = 1.4426 * 0.5 * (torch.pow(mu, 2) + var - 1.0 - logvar)
      # reshape to (-1, dim, codebook number)
      kl2 = kl2.reshape(-1, self.dim, codebook_num)
      kl2 = torch.sum(kl2, dim=1) # sum over dim
      
      # compute mean, min, max of kl divergence
      kl2_mean, kl2_min, kl2_max = torch.mean(kl2), torch.min(kl2), torch.max(kl2)
      # per-sample weights: lam_max above the tolerance band, 1 inside it, lam_min below it
      ge = (kl2 > self.log_n_samples + self.tolerance).type(kl2.dtype) * self.lam_max
      eq = (kl2 <= self.log_n_samples + self.tolerance).type(kl2.dtype) * (
          kl2 >= self.log_n_samples - self.tolerance
      ).type(kl2.dtype)
      le = (kl2 < self.log_n_samples - self.tolerance).type(kl2.dtype) * self.lam_min
      
      # reweight kl divergence according to its relation to log2 codebook_size
      kl_loss = ge * kl2 + eq * kl2 + le * kl2
      kl_loss = torch.mean(kl_loss) * self.lam
    • The codebook vectors are purely Gaussian noise and the codebook is never updated. During inference, we simply find the noise vector closest to mu (a sketch of how such a fixed codebook could be constructed is in the notes at the end of this comment):
      q_normal_dist = Normal(mu_q[:, None, :], std_q[:, None, :])
      # log q(z|x) - beta * log p(z), evaluated at every fixed prior sample
      log_ratios = (
          q_normal_dist.log_prob(self.prior_samples[None])
          - self.normal_log_prob[None] * self.beta
      )
      perturbed = torch.sum(log_ratios, dim=2) # sum over codebook dim
      # pick the prior sample with the largest log ratio
      argmax_indices = torch.argmax(perturbed, dim=1)
      zhat[i : i + bs] = torch.index_select(self.prior_samples, 0, argmax_indices)
      indices[i : i + bs] = argmax_indices
  • Practically:
    • GaussianQuant takes input of any shape, but you need to specify which dimension to quantize.
    • GaussianQuant has the exact same interface as VQ-VAE: you choose the codebook dim and codebook size, and GaussianQuant implicitly infers the codebook number.
    • Unlike other VQ-VAEs, the input Z should contain twice the channels: the first half is the mean of the Gaussian VAE, the second half is the logvar.
    • A usage example can be found in ./examples/autoencoder_gq.py:
        python ./examples/autoencoder_gq.py
        # gives: rec loss VQ: 0.121 | kl loss: 352867.281 | active %: 61.328
    • A usage example is also given below:
        import torch
        from vector_quantize_pytorch import GaussianQuant

        mu = torch.zeros([1, 6, 64, 64])     # (B, C, H, W)
        logvar = torch.zeros([1, 6, 64, 64])
        z = torch.cat([mu, logvar], dim=1)   # first half: mean, second half: logvar

        # quantize the C dimension
        # codebook dim = 3, input dim = 6, so codebook number is 2
        gq = GaussianQuant(dim=3, dim_idx=1, codebook_size=128)

        zhat, log = gq(z)
        loss = log["kl-loss"]
        indices = log["indices"]
        print(indices.shape)                 # (B, codebook number, H, W)

        zhat2 = gq.indices_to_codes(indices)
        print(torch.sum(torch.abs(zhat - zhat2))) # 0
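  • Notes (standalone sketches, not part of gaussian_quant.py):
    • The 1.4426 factor in the KL loss is log2(e), so kl2 is the KL divergence measured in bits and can be compared directly against log2(codebook_size). A minimal sanity check with made-up posterior parameters:
      import math
      import torch

      codebook_size, dim = 128, 3
      target_bits = math.log2(codebook_size)  # 7.0 bits per codebook

      # hypothetical per-dimension posterior parameters
      mu = torch.full((dim,), 1.78)
      logvar = torch.full((dim,), -0.5)
      var = logvar.exp()

      # same formula as in the KL loss above: KL(q || N(0,1)) in bits, summed over dim
      kl_bits = (1.4426 * 0.5 * (mu.pow(2) + var - 1.0 - logvar)).sum()
      print(float(kl_bits), target_bits)  # roughly 7.1 vs. 7.0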
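    • The fixed Gaussian codebook can be thought of as codebook_size samples drawn once from N(0, 1). The following is a sketch under that assumption (the names prior_samples / normal_log_prob mirror the inference snippet above; the exact initialization in gaussian_quant.py may differ):
      import torch
      from torch.distributions import Normal

      codebook_size, codebook_dim, beta = 128, 3, 1.0
      prior_samples = torch.randn(codebook_size, codebook_dim)            # frozen codebook, never trained
      normal_log_prob = Normal(0.0, 1.0).log_prob(prior_samples).sum(-1)  # log p(z) for each sample

      # quantize one posterior N(mu, std): pick the sample maximizing log q(z|x) - beta * log p(z)
      mu, std = torch.randn(codebook_dim), torch.ones(codebook_dim)
      log_ratio = Normal(mu, std).log_prob(prior_samples).sum(-1) - beta * normal_log_prob
      index = torch.argmax(log_ratio)
      zhat = prior_samples[index]
      print(index.item(), zhat)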

@lucidrains (Owner)

nice work! i think it may be out of scope for this repository, but i'll keep your trick at the back of my head for future projects

@lucidrains closed this Dec 7, 2025