
Conversation

EIFY commented Apr 17, 2025

Includes the necessary modules (LN, bias, scale, constant posemb) and a notebook showing it working on MNIST. Initial tuning shows that momentum with dualize is not quite as performant as Adam, but

  1. That was a 10-minute effort
  2. Not sure how seriously we should take MNIST performance after 1000 steps anyway

More of a demonstration, but it could be merged.

includes necessary modules (LN, bias, scale, constant posemb)
EIFY mentioned this pull request Apr 17, 2025
EIFY added 3 commits April 16, 2025 20:36
In most cases image data has a channel dimension (usually 3), so making
the input shape NHWC makes the ViT more applicable externally.
Unexpectedly, bias alone stabilizes the model, though it is not nearly as
performant.
Just like not caching RoPE, it's more in the spirit of JAX to defer
constant materialization to the forward pass; this also allows distributed
init without triggering a disallowed host-to-device transfer error.
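
As an illustration of that pattern (a minimal sketch, not the PR's code; the module and helper names are made up), the constant sincos posemb can be built with jnp inside __call__, so it only materializes inside the traced forward pass:

import jax.numpy as jnp
import flax.linen as nn

def sincos_posemb(seq_len, dim):
    # Fixed sin/cos positional embedding of shape [seq_len, dim], built on the fly.
    pos = jnp.arange(seq_len)[:, None]
    i = jnp.arange(dim // 2)[None, :]
    angles = pos / (10_000 ** (2 * i / dim))
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)

class AddConstantPosemb(nn.Module):
    @nn.compact
    def __call__(self, x):  # x: [batch, seq_len, dim]
        # Nothing is stored at init time; the constant is traced along with
        # the rest of the forward pass, so distributed init works without an
        # explicit host-to-device transfer.
        return x + sincos_posemb(x.shape[1], x.shape[2])[None]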

EIFY commented May 4, 2025

I have ported this "dualized ViT" to Big Vision and started training a dualized ViT-S/16 on ImageNet-1k.
TL;DR: So far it severely underperforms the baseline, but there are still a few knobs to turn.

Implementation: I made a custom branch of modula that dry-runs dualize recursively and caches the target_norm for each atom. I then export these, along with the class name of each atom module, and graft the behavior of dualize onto the Big Vision implementation using optax.partition from the optax nightly (see the Big Vision branch that does this); a rough sketch of the grafting is included below the plot. So far the training loss looks like this:

[Screenshot: training loss, dualized ViT-S/16 vs. baseline]

wandb report
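
For reference, here is a rough sketch of the grafting mechanism, assuming the modula dry run exports a pytree of atom class names matching the param tree; shown with optax.multi_transform, which I believe is the stable-API counterpart of the nightly's partition. The label names and the per-group chains below are illustrative placeholders, not the actual branch:

import optax

def dualized_optimizer(lr, atom_labels):
    # atom_labels: a pytree with the same structure as the params, whose
    # leaves are the atom class names exported from the modula dry run.
    transforms = {
        # Each group gets its own update rule (see the "Optimizer
        # differences" list below); plain heavy-ball momentum stands in here.
        "Linear": optax.chain(optax.trace(decay=0.95), optax.scale(-lr)),
        "Embed": optax.chain(optax.trace(decay=0.95), optax.scale(-lr)),
        "Bias": optax.chain(optax.trace(decay=0.95), optax.scale(-lr)),
        "Scale": optax.chain(optax.trace(decay=0.95), optax.scale(-lr)),
    }
    return optax.multi_transform(transforms, atom_labels)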

LR=0.05, WD=0.005 or 0.0001 (decoupled from LR), momentum w/ beta=0.95
No warm-up, cosine learning rate decay.
Notably, l2_params starts out not too different from the baseline but quickly grows to 50x of it. My knee-jerk reaction of raising WD from 0.0001 to 0.005, however, seems to make the model even worse:
[Screenshot: training loss after raising WD to 0.005]

I don't have great intuition here. Perhaps it still needs LR warm-up even though the model training is stable without it?
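
If I do add warm-up, it's just a schedule swap, e.g. (all numbers below are placeholders, not a tuned setting):

import optax

total_steps = 100_000   # placeholder for the actual number of training steps
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.05,      # current peak LR
    warmup_steps=10_000,  # placeholder warm-up length
    decay_steps=total_steps,
    end_value=0.0,
)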

Notable architecture (or rather, just scaling) differences from the baseline (see the sketch after this list):

  1. GELU is scaled by 1/1.1289 to keep the derivative <= 1.
  2. The dot product of the dot-product attention is scaled by 1/d instead of 1/sqrt(d), where d is the dimension of the attention head (same as μP). The final attention output is then scaled by 1/3.
  3. The residual branch is scaled by s = 1/(2 * depth) and the residual connection is normalized as x = (1 - s) * x + s * y. This is more aggressive than the 1/sqrt(depth) residual branch scaling for ViT here.
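
In code, these three scaling changes look roughly like this (illustrative shapes and names, not the exact Big Vision code):

import jax
import jax.numpy as jnp

def scaled_gelu(x):
    # max |gelu'(x)| ≈ 1.1289, so this keeps the derivative <= 1.
    return jax.nn.gelu(x) / 1.1289

def attention(q, k, v):
    # q, k, v: [batch, heads, seq, d]; 1/d scaling instead of 1/sqrt(d),
    # and the combined output is scaled by 1/3.
    d = q.shape[-1]
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) / d
    out = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(logits, axis=-1), v)
    return out / 3

def residual(x, block, depth):
    # Normalized residual connection with s = 1/(2 * depth).
    s = 1.0 / (2 * depth)
    return (1 - s) * x + s * block(x)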

Optimizer differences from the "conventional" Muon (see the sketch after this list):

  1. Due to the different target_norm values we effectively have a different LR for each layer.
  2. Instead of an AdamW fallback, we use momentum + an L2-normalized update for bias (same as Embed) and Lion-like momentum + a sign update for scale (following the derivation of Muon and treating the scale of LayerNorm as the linear operator diag(scale)).
  3. We don't clip the sqrt(fanout / fanin) factor from below as sqrt(max(1, fanout / fanin)). It is what it is.
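
To make these concrete, here is a plain-JAX sketch of the raw (pre-LR) per-group updates, assuming the matrix dualization is the usual Muon-style orthogonalization via Newton-Schulz iterations (the iteration count, coefficients, and exact placement of target_norm are illustrative, not the branch's exact values):

import jax.numpy as jnp

def orthogonalize(m, steps=5):
    # Cubic Newton-Schulz iteration approximating the orthogonal polar factor of m.
    x = m / (jnp.linalg.norm(m) + 1e-8)
    for _ in range(steps):
        x = 1.5 * x - 0.5 * x @ x.T @ x
    return x

def matrix_update(momentum, target_norm):
    fanout, fanin = momentum.shape
    # Note: sqrt(fanout / fanin) is NOT clipped from below (point 3 above).
    return target_norm * jnp.sqrt(fanout / fanin) * orthogonalize(momentum)

def bias_update(momentum, target_norm):   # same rule for Embed
    return target_norm * momentum / (jnp.linalg.norm(momentum) + 1e-8)

def scale_update(momentum, target_norm):  # Lion-like sign update
    return target_norm * jnp.sign(momentum)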

If necessary, it may be interesting to bisect these differences.

x1 = x[..., self.rope_dim:] # shape [batch, n_heads, seq_len, rope_dim]
x2 = x[..., :self.rope_dim] # shape [batch, n_heads, seq_len, rope_dim]

# Why is the order reversed!?

I noticed this too and corrected it in #13. I guess it doesn't really matter for performance?
