simple ViT implementation #10
base: main
Conversation
includes necessary modules (LN, bias, scale, constant posemb)
In most cases image data has a channel dimension (usually 3), so making the input shape NHWC makes the ViT more applicable externally. Unexpectedly, bias alone stabilizes the model, though it is not nearly as performant.
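For concreteness, a minimal sketch of what consuming NHWC input directly could look like (illustrative only, not the PR's code; `patchify_nhwc` and `patch` are made-up names):

```python
import jax.numpy as jnp

def patchify_nhwc(images, patch):
    # [N, H, W, C] -> [N, num_patches, patch * patch * C]
    n, h, w, c = images.shape
    x = images.reshape(n, h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # bring the two patch axes next to each other
    return x.reshape(n, (h // patch) * (w // patch), patch * patch * c)
```

A linear layer over the last axis then gives the token embeddings.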
Just like not caching RoPE, it's more in the spirit of JAX to defer constant materialization to the forward pass; this also allows distributed initialization without triggering a disallowed host-to-device transfer error.
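As a sketch of the idea (assuming a 1-D sincos table; the PR's constant posemb may be built differently), the table can simply be computed inside the forward function, where jit folds it into the compiled graph instead of requiring a host-side buffer created at init:

```python
import jax.numpy as jnp

def sincos_posemb(seq_len, dim):
    # Fixed sin/cos table, materialized on the fly rather than cached at init.
    pos = jnp.arange(seq_len)[:, None]                                   # [seq_len, 1]
    freqs = jnp.exp(-jnp.log(10000.0) * jnp.arange(0, dim, 2) / dim)     # [dim / 2]
    angles = pos * freqs[None, :]                                        # [seq_len, dim / 2]
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)  # [seq_len, dim]

def add_posemb(tokens):
    # tokens: [N, seq_len, dim]; under jit the constant is traced on device,
    # so no host-to-device transfer is needed at initialization time.
    n, seq_len, dim = tokens.shape
    return tokens + sincos_posemb(seq_len, dim)[None]
```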
I have ported this "dualized ViT" to Big Vision and started training a dualized ViT-S/16 on ImageNet-1k. Implementation: I made a custom branch of modula that dry-runs
LR=0.05, WD=0.005 or 0.0001 (decoupled from LR), momentum w/ beta=0.95 (a generic sketch of these settings appears after the lists below). I don't have great intuition here; perhaps it still needs LR warm-up even though training is stable without it?
Notable architecture (or rather, just scaling) differences from the baseline:
Optimizer differences from the "conventional" muon:
If necessary it may be interesting to bisect these differences.
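For reference, a generic sketch of the momentum and decoupled weight-decay bookkeeping under one reading of the settings above (this is not the dualized/Muon update itself, and the multiplicative decay form is an assumption):

```python
import jax

def momentum_step(params, grads, velocity, lr=0.05, wd=0.005, beta=0.95):
    # Heavy-ball momentum; weight decay shrinks the weights directly,
    # i.e. it is not scaled by the learning rate ("decoupled from LR").
    velocity = jax.tree_util.tree_map(lambda v, g: beta * v + g, velocity, grads)
    params = jax.tree_util.tree_map(lambda p, v: (1.0 - wd) * p - lr * v,
                                    params, velocity)
    return params, velocity
```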
```python
x1 = x[..., self.rope_dim:]  # shape [batch, n_heads, seq_len, rope_dim]
x2 = x[..., :self.rope_dim]  # shape [batch, n_heads, seq_len, rope_dim]
# Why is the order reversed!?
```
I noticed this too and corrected it in #13. I guess it doesn't really matter for performance?
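A quick sketch of why the swap is plausibly harmless (illustrative, not the PR's rotary code): reversing which half plays the role of x1 vs. x2 amounts to rotating each feature pair by -θ instead of +θ, and the attention score still depends only on the relative position either way.

```python
import jax
import jax.numpy as jnp

def rotate(x, theta):
    # Conventional split: x1 = first half, x2 = second half.
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([x1 * jnp.cos(theta) - x2 * jnp.sin(theta),
                            x1 * jnp.sin(theta) + x2 * jnp.cos(theta)], axis=-1)

def rotate_swapped(x, theta):
    # Reversed split, mirroring the quoted lines: x1 = second half, x2 = first half.
    x2, x1 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([x1 * jnp.cos(theta) - x2 * jnp.sin(theta),
                            x1 * jnp.sin(theta) + x2 * jnp.cos(theta)], axis=-1)

q, k = jax.random.normal(jax.random.PRNGKey(0), (2, 8))
freqs = jnp.arange(1, 5, dtype=jnp.float32)  # one angle per feature pair

def score(rot, m, n):
    return rot(q, m * freqs) @ rot(k, n * freqs)

# In both conventions the score depends only on the offset n - m:
print(score(rotate, 3.0, 5.0), score(rotate, 10.0, 12.0))                  # equal up to float error
print(score(rotate_swapped, 3.0, 5.0), score(rotate_swapped, 10.0, 12.0))  # equal up to float error
```

Since queries and keys use the same convention, the swap is just a fixed relabeling of the feature pairs, which the learned projections can absorb.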


includes necessary modules (LN, bias, scale, constant posemb) and a notebook showing it working on MNIST. Initial tuning shows that momentum w/ dualize is not quite as performant as Adam. More of a demonstration, but it could be merged.