Modula composite module in jax -> nn.Sequential in PyTorch #13
Conversation
Other stuff:
- make rotary match modded-nanogpt
- option to automatically substitute in `scaled_dot_product_attention`
- tests for all the torch modules
- tests for the model conversion

My reading is that they still trained those models with PyTorch. See this folder in their repo: https://github.com/Arongil/lipschitz-transformers/tree/main/nanogpt

I have also done a similar bridge, but from JAX to JAX (lol), for scalability and supporting infra (data pipeline, logging, etc.), though it underperformed: #10 (comment)
```python
import jax
import jax.numpy as jnp

from modula.abstract import Atom, get_leaf_modules, get_leaf_target_norms
from modula.atom import Linear

# get_unnormalized_dual is assumed to be defined earlier in this file (not shown in the diff hunk).


def test_dualize_consistency(module, grad_w, target_norm=1.0, rtol=1e-6):
    """
    Test that get_unnormalized_dual and get_leaf_target_norms produce results
    consistent with the actual dualize method.

    Args:
        module: A Module instance
        grad_w: Weight gradient list
        target_norm: Target norm to test with
        rtol: Relative tolerance for comparison

    Returns:
        bool: True if consistent, False otherwise
    """
    # Get results from actual dualize
    actual_dual = module.dualize(grad_w, target_norm=target_norm)

    # Get results from our functions
    unnormalized_dual = get_unnormalized_dual(module, grad_w)
    leaf_modules = get_leaf_modules(module)
    target_norms = get_leaf_target_norms(module, target_norm=target_norm)

    # Apply target norms to unnormalized dual
    predicted_dual = []
    weight_idx = 0

    for leaf_module, leaf_target_norm in zip(leaf_modules, target_norms):
        if isinstance(leaf_module, Atom):  # Only atoms have weights
            leaf_weights = unnormalized_dual[weight_idx:weight_idx + leaf_module.atoms]
            # Apply the target norm
            scaled_weights = [w * leaf_target_norm for w in leaf_weights]
            predicted_dual.extend(scaled_weights)
            weight_idx += leaf_module.atoms
        # Bonds have no weights, so nothing to add to predicted_dual

    # Compare actual vs predicted
    if len(actual_dual) != len(predicted_dual):
        print(f"Length mismatch: actual {len(actual_dual)}, predicted {len(predicted_dual)}")
        return False

    for i, (actual, predicted) in enumerate(zip(actual_dual, predicted_dual)):
        if not jnp.allclose(actual, predicted, rtol=rtol):
            print(f"Mismatch at weight {i}")
            print(f"Actual shape: {actual.shape}, Predicted shape: {predicted.shape}")
            print(f"Max difference: {jnp.max(jnp.abs(actual - predicted))}")
            return False

    print("✓ Dualize consistency test passed!")
    return True


# Example usage:
def test_example():
    """Example test with a simple module"""
    # Create a simple module
    linear = Linear(fanout=4, fanin=3)
    linear @= Linear(fanout=4, fanin=4)  # Add another linear layer
    linear @= Linear(fanout=2, fanin=4)  # Add another linear layer

    # Initialize weights and create some gradient
    key = jax.random.PRNGKey(42)
    weights = linear.initialize(key)
    grad_w = [jax.random.normal(key, shape=w.shape) for w in weights]

    # Test consistency
    return test_dualize_consistency(linear, grad_w, target_norm=2.5)
```
Oops I didn't mean to leave this in, I was running some unrelated tests. I'll remove this.
These tests are a mess to be honest, lazy vibe coding. They do mostly work, at least.
```python
super().__init__(num_embed, d_embed, padding_idx=padding_idx)
self.num_embed = num_embed
self.d_embed = d_embed
self.sensitivity = 1
```
I was trying not to include sensitivity attributes that don't do anything (the target norms are inherited from the Jax module and don't do anything after conversion).
Noticed that I was applying the momentum wrong in the experiments: the output of the EMA is supposed to go into the dualize function, not the raw gradient, and I didn't have any EMA at all. Edited the experiment to do this instead, using this class: https://github.com/gngdb/modula/blob/pytorch_bridge_experiments_with_ema/modded-nanogpt/train_gpt2.py#L26-L71
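For concreteness, a minimal sketch of that ordering with a toy model and placeholder names (the real implementation is the EMA class linked above; `dualize_gradients` stands in for the dualization step):

```python
import torch

# Toy stand-ins just to show the order of operations: raw grad -> EMA -> dualize -> step.
model = torch.nn.Linear(8, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
beta = 0.95
ema = [torch.zeros_like(p) for p in model.parameters()]  # running average of raw gradients

x, y = torch.randn(16, 8), torch.randn(16, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()

for p, buf in zip(model.parameters(), ema):
    buf.mul_(beta).add_(p.grad, alpha=1 - beta)  # the EMA sees the raw gradient
    p.grad.copy_(buf)                            # the EMA output is what gets dualized

# dualize_gradients(model)  # dualize the (EMA'd) gradients in place; name/signature assumed
opt.step()
opt.zero_grad()
```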
Tried to get a better final loss on the same speedrun benchmark as above, but it was only marginally better, going from 3.418 to 3.395 validation loss. The log of that experiment is here. Tried varying the learning rate but I didn't see better results, and training was not stable at higher learning rates. To get the final loss to match the speedrun result I have three ideas:
modula/to_pytorch.py (Outdated)

```python
module = modula.atom.Linear(fanout=4, fanin=3)
module @= (modula.atom.Linear(fanout=2, fanin=4), modula.atom.Linear(fanout=2, fanin=4))
module @= modula.atom.Linear(fanout=2, fanin=4)
```
This example module is wrong because the hidden dimensions are mismatched. The MLP is constructed from the bottom to the top, so the input should be 4-dim.
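For anyone following along: `@` composes right-to-left in Modula, so the right-hand module sees the input first and the dimensions chain from the bottom up. A dimension-consistent toy example (not the fix from this PR):

```python
from modula.atom import Linear

# The right-hand module is applied first, so this maps a 4-dim input -> 3 -> 2.
mlp = Linear(fanout=2, fanin=3) @ Linear(fanout=3, fanin=4)
```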
Also, the TupleModule of two modules is not concatenated.
Oh yeah, I never ran that example; that network is not runnable:

`RuntimeError: mat1 and mat2 shapes cannot be multiplied (5x3 and 4x2)`

Just updated to fix it and test that it runs. The tuple isn't concatenated because there is no bond module by default to concat tuples; that only happens in the AttentionQK module. I didn't want to define one just for this example.
I really wanted to run the modded-nanogpt benchmark and other PyTorch training flows using Modula-defined models, so I wrote some code that walks the module tree to build an `nn.Sequential` model while setting the `target_norm` in each Atomic module so the dualization still works. An example of converting the GPT model from the example notebook is here. The conversion function is in `modula/to_pytorch.py`; example usage:
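Roughly along these lines, with a stand-in model and an assumed `sequentialise` signature (the real usage is in `to_pytorch.py`):

```python
import jax
from modula.atom import Linear
from modula import to_pytorch  # the conversion code added in this PR

# A small Modula (jax) model, built right-to-left: 16-dim input -> 32 -> 10.
model = Linear(fanout=10, fanin=32) @ Linear(fanout=32, fanin=16)
weights = model.initialize(jax.random.PRNGKey(0))

# Optionally equip the jax weights first so they get copied across:
# abstract.set_atomic_weights(model, weights)  # helper named in this PR; signature assumed

# Walk the module tree and build the equivalent nn.Sequential.
torch_model = to_pytorch.sequentialise(model)
```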
Then, before running the update, call the `dualize_gradients` methods with a utility function (sketched below). The full training script lives in another branch, to contain the mess, here.
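The update step then looks something like this (a plain `nn.Sequential` stands in for the converted model so the sketch runs on its own; `dualize_gradients` is a placeholder for the utility in the training script):

```python
import torch

# Stand-in for the nn.Sequential produced by to_pytorch.sequentialise.
torch_model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Linear(32, 10))
opt = torch.optim.SGD(torch_model.parameters(), lr=0.1)

loss = torch_model(torch.randn(8, 16)).square().mean()
loss.backward()

# dualize_gradients(torch_model)  # rescale each .grad to its dual using the stored target_norm; signature assumed
opt.step()
opt.zero_grad()
```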
This PR includes:

- PyTorch versions of the Modula modules (`modula/torch_modules.py`), plus tests to make sure they match the Jax modules (I had to slightly change the Modula definition of Rope to make it match the one in modded-nanogpt because I wanted it to be consistent).
- Tests included in the `to_pytorch.py` file so they can be run: `python to_pytorch.py`.
- `to_pytorch.sequentialise`: maps the jax module structure directly to an `nn.Sequential` and `torch_modules.Parallel` structure. It can copy all the parameters across as well if they're equipped to the jax modules beforehand using `abstract.set_atomic_weights`.
- `to_pytorch.flash_sequentialise`: does the same thing as `sequentialise` but replaces Attention blocks with a functionally identical block using flash attention. Internally, it wastefully instantiates a model without flash attention and loads the state dictionary from this model; this could be avoided with meta devices.
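On the flash attention substitution: the replacement block boils down to calling `torch.nn.functional.scaled_dot_product_attention`, which computes the same thing as the explicit attention. A generic equivalence check (not code from this PR; Modula's particular attention scaling can be matched via the `scale` argument):

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))  # (batch, heads, seq, head_dim)

# Explicit causal attention, the way a plain module would compute it.
scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
causal = torch.triu(torch.ones(16, 16, dtype=torch.bool), diagonal=1)
explicit = torch.softmax(scores.masked_fill(causal, float("-inf")), dim=-1) @ v

# The fused kernel computing the same thing.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(explicit, fused, atol=1e-5))  # True
```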
I was able to run the 22-minute modded-nanogpt record (I wanted a record relatively close to GPT2/llm.c). I didn't tune it properly, just tried a few runs. The best run is here, achieving 3.4187 validation loss using about 3.2B tokens; runtime was 27.65 min on 8xH100. I was going to keep tuning but then I realised this is comparable to the Modula-only record in Training Transformers with Enforced Lipschitz Constants (although with a larger token budget):

I'm not sure if this tool is something the maintainers are interested in, but I'd be happy to clean it up and run a few more experiments to close the gap on token budget in the experiment. I also only just realised the lipschitz-transformers code is up, so I need to look there for the right settings for the 0.7B token budget run.