gngdb commented Jul 30, 2025

I really wanted to run the modded-nanogpt benchmark and other PyTorch training flows using Modula-defined models, so I wrote some code that walks the module tree to build an nn.Sequential model while setting the target_norm on each Atomic module so that dualization still works. An example of converting the GPT model from the example notebook is here.

The conversion function is in modula/to_pytorch.py; example usage:

model = ModulaGPT(
    vocab_size=config.vocab_size,
    num_heads=config.n_head,
    d_embed=config.n_embd,
    d_query=config.n_embd // config.n_head,
    d_value=config.n_embd // config.n_head,
    num_blocks=config.n_layer,
    blocks_mass=config.blocks_mass,
    attention_scale=attention_scale,
    final_scale=final_scale,
    )
# this sets the target norms as attributes on the atomic modules
_ = get_leaf_target_norms(model, target_norm=1.0)
self.transformer = flash_sequentialise(model)

Then, before running the update, call the dualize_gradients method on the atomic modules via a utility function:

def dualize(model):
    """
    For all atomic modules, we need to run the dualize method before running
    gradient descent.
    """
    for m in model.modules():
        if hasattr(m, 'dualize_gradients'):
            m.dualize_gradients()
...
# in training loop
dualize(raw_model)
optimizer.step()
scheduler.step()

The full training script is in another branch, to keep the mess out of this PR; it's here.

This PR includes:

  • Unit tests for all the PyTorch modules (in modula/torch_modules.py) to make sure they match the Jax modules (I had to slightly change the Modula definition of Rope to match the one in modded-nanogpt, because I wanted the two to be consistent).
  • Some basic integration tests checking that the converted models produce the same forward passes; these are included in to_pytorch.py so they can be run with python to_pytorch.py.
  • Two conversion functions:
    • to_pytorch.sequentialise: maps the jax module structure directly onto an nn.Sequential and torch_modules.Parallel structure; a rough sketch of the recursion pattern is shown after this list. It can also copy all the parameters across if they're equipped to the jax modules beforehand using abstract.set_atomic_weights.
    • to_pytorch.flash_sequentialise: does the same thing as sequentialise but replaces Attention blocks with functionally identical blocks that use flash attention. Internally, it wastefully instantiates a model without flash attention and loads the state dictionary from it; this could be avoided with meta devices.
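
For context, the general shape of the tree walk is roughly the following. This is a minimal sketch, not the actual sequentialise code: the toy tree encoding (lists for composition, tuples for parallel branches, ("linear", fanin, fanout) leaves) and the local Parallel class are invented stand-ins for Modula's module classes and torch_modules.Parallel.

import torch.nn as nn

class Parallel(nn.Module):
    """Apply each branch to the same input and return a tuple of outputs."""
    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        return tuple(branch(x) for branch in self.branches)

def sequentialise_sketch(node, target_norm=1.0):
    """Recursively map a toy module tree onto torch modules.

    Lists compose in order (nn.Sequential), tuples of sub-trees run in
    parallel (Parallel), and ("linear", fanin, fanout) leaves become
    nn.Linear with the target norm stashed as an attribute so the
    dualization step can find it later.
    """
    if isinstance(node, list):
        return nn.Sequential(*[sequentialise_sketch(c, target_norm) for c in node])
    if isinstance(node, tuple) and not isinstance(node[0], str):
        return Parallel([sequentialise_sketch(c, target_norm) for c in node])
    kind, fanin, fanout = node
    assert kind == "linear"
    leaf = nn.Linear(fanin, fanout, bias=False)
    leaf.target_norm = target_norm  # read when dualizing gradients
    return leaf

# toy tree: Linear(3 -> 4) followed by two parallel Linear(4 -> 2) branches
model = sequentialise_sketch([("linear", 3, 4), (("linear", 4, 2), ("linear", 4, 2))])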

I was able to run the 22-minute modded-nanogpt record (I wanted a record relatively close to GPT-2/llm.c). I didn't tune it properly, just tried a few runs. The best run is here, reaching a validation loss of 3.4187 using about 3.2B tokens; runtime was 27.65 minutes on 8xH100. I was going to keep tuning, but then I realised this is comparable to the Modula-only record in Training Transformers with Enforced Lipschitz Constants (although with a larger token budget):

(screenshot of the corresponding result from that paper)

I'm not sure whether this tool is something the maintainers are interested in, but I'd be happy to clean it up and run a few more experiments to close the gap on token budget. I also only just realised the lipschitz-transformers code is up, so I need to look there for the right settings for the 0.7B-token-budget run.

Other stuff:
- make rotary match modded-nanogpt
- optionally, automatic substitution of `scaled_dot_product_attention` (rough sketch after this list)
- tests for all the torch modules
- tests for the model conversion
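
A rough sketch of what the optional substitution amounts to: route the attention forward through F.scaled_dot_product_attention instead of an explicit softmax. This is not the torch_modules implementation, just an illustration that the two paths agree numerically; note that SDPA uses the standard 1/sqrt(d) scale by default, and newer PyTorch versions expose a scale argument if the Modula attention uses a different scaling.

import torch
import torch.nn.functional as F

def attention_forward(q, k, v, use_flash=True):
    """Causal attention on (batch, heads, seq, head_dim) tensors."""
    if use_flash:
        # can dispatch to flash attention kernels
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # explicit reference path
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    causal = torch.triu(torch.ones(q.shape[-2], k.shape[-2], dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

# sanity check that the two paths match
q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
assert torch.allclose(attention_forward(q, k, v, True), attention_forward(q, k, v, False), atol=1e-5)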

EIFY commented Jul 30, 2025

I was going to keep tuning but then I realised this is comparable to the Modula-only record in Training Transformers with Enforced Lipschitz Constants (although with a larger token budget)

My reading is that they still trained those models with PyTorch. See this folder in their repo: https://github.com/Arongil/lipschitz-transformers/tree/main/nanogpt

I have also built a similar bridge, but from JAX to JAX (lol), for scalability and supporting infra (data pipeline, logging, etc.); it underperformed: #10 (comment)

Comment on lines +95 to +161
from modula.abstract import get_leaf_modules, get_leaf_target_norms

def test_dualize_consistency(module, grad_w, target_norm=1.0, rtol=1e-6):
    """
    Test that get_unnormalized_dual and get_leaf_target_norms produce results
    consistent with the actual dualize method.
    Args:
        module: A Module instance
        grad_w: Weight gradient list
        target_norm: Target norm to test with
        rtol: Relative tolerance for comparison
    Returns:
        bool: True if consistent, False otherwise
    """
    # Get results from actual dualize
    actual_dual = module.dualize(grad_w, target_norm=target_norm)

    # Get results from our functions
    unnormalized_dual = get_unnormalized_dual(module, grad_w)
    leaf_modules = get_leaf_modules(module)
    target_norms = get_leaf_target_norms(module, target_norm=target_norm)

    # Apply target norms to unnormalized dual
    predicted_dual = []
    weight_idx = 0

    for leaf_module, leaf_target_norm in zip(leaf_modules, target_norms):
        if isinstance(leaf_module, (Atom)):  # Only atoms have weights
            leaf_weights = unnormalized_dual[weight_idx:weight_idx + leaf_module.atoms]
            # Apply the target norm
            scaled_weights = [w * leaf_target_norm for w in leaf_weights]
            predicted_dual.extend(scaled_weights)
            weight_idx += leaf_module.atoms
        # Bonds have no weights, so nothing to add to predicted_dual

    # Compare actual vs predicted
    if len(actual_dual) != len(predicted_dual):
        print(f"Length mismatch: actual {len(actual_dual)}, predicted {len(predicted_dual)}")
        return False

    for i, (actual, predicted) in enumerate(zip(actual_dual, predicted_dual)):
        if not jnp.allclose(actual, predicted, rtol=rtol):
            print(f"Mismatch at weight {i}")
            print(f"Actual shape: {actual.shape}, Predicted shape: {predicted.shape}")
            print(f"Max difference: {jnp.max(jnp.abs(actual - predicted))}")
            return False

    print("✓ Dualize consistency test passed!")
    return True

# Example usage:
def test_example():
    """Example test with a simple module"""
    # Create a simple module
    linear = Linear(fanout=4, fanin=3)
    linear @= Linear(fanout=4, fanin=4)  # Add another linear layer
    linear @= Linear(fanout=2, fanin=4)  # Add another linear layer

    # Initialize weights and create some gradient
    key = jax.random.PRNGKey(42)
    weights = linear.initialize(key)
    grad_w = [jax.random.normal(key, shape=w.shape) for w in weights]

    # Test consistency
    return test_dualize_consistency(linear, grad_w, target_norm=2.5)
gngdb (Author):

Oops, I didn't mean to leave this in; I was running some unrelated tests. I'll remove it.

gngdb (Author):

These tests are a mess, to be honest; lazy vibe coding. They do mostly work, at least.

        super().__init__(num_embed, d_embed, padding_idx=padding_idx)
        self.num_embed = num_embed
        self.d_embed = d_embed
        self.sensitivity = 1
gngdb (Author):

I was trying not to include sensitivity attributes that don't do anything (the target norms are inherited from the Jax module and don't do anything after conversion).


gngdb commented Sep 10, 2025

I noticed that I was applying momentum wrong in the experiments. The output of the EMA is supposed to go into the dualize function, not the raw gradient, and I didn't have any EMA at all. I edited the experiment to do this using this class: https://github.com/gngdb/modula/blob/pytorch_bridge_experiments_with_ema/modded-nanogpt/train_gpt2.py#L26-L71
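
Roughly, the fix is an EMA of the gradients whose output is what dualize sees. A minimal sketch of the idea (not the class linked above; beta and the choice to write the EMA back into .grad are just placeholder choices):

import torch

class GradEMA:
    """Keep an exponential moving average of each parameter's gradient and
    write it back into param.grad, so dualize() acts on the momentum buffer
    rather than the raw gradient."""
    def __init__(self, params, beta=0.95):
        self.params = [p for p in params if p.requires_grad]
        self.beta = beta
        self.buffers = [torch.zeros_like(p) for p in self.params]

    @torch.no_grad()
    def update(self):
        for p, buf in zip(self.params, self.buffers):
            if p.grad is None:
                continue
            buf.mul_(self.beta).add_(p.grad, alpha=1 - self.beta)
            p.grad.copy_(buf)  # dualize() now sees the EMA, not the raw grad

# in the training loop, after backward():
#   ema.update()
#   dualize(raw_model)
#   optimizer.step()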

I tried to get a better final loss on the same speedrun benchmark as above, but it was only marginally better: validation loss went from 3.418 to 3.395. The log of that experiment is here. I tried varying the learning rate, but I didn't see better results and training was not stable at higher learning rates.

To get the final loss to match the speedrun result, I have three ideas:

  1. Look at gradient norms and activations to pinpoint why training diverges at higher learning rates.
  2. Change to a different base speedrun checkpoint that uses untied embeddings. The 22-minute record I based this on is designed for tied embeddings, so that is likely confounding the results. The 10-minute record is probably a better starting point, but it would require implementing squared ReLU, zero-init projections and QK-norm in the Modula model and making sure it converts OK (the PyTorch side of those pieces is sketched after this list).
  3. Look at the expected scale of the updates in a given speedrun config as a guide for how to set the mass parameters.
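
On the PyTorch side, those pieces from idea 2 are small; the work is really in the Modula norm bookkeeping and the conversion. A rough sketch of the first two (layer widths and names are illustrative, not taken from the speedrun config; QK-norm would similarly just normalise q and k before the dot product):

import torch
import torch.nn as nn

class SquaredReLU(nn.Module):
    """relu(x) ** 2, as used in the faster modded-nanogpt records."""
    def forward(self, x):
        return torch.relu(x).square()

class MLPSketch(nn.Module):
    """Toy MLP block with squared ReLU and a zero-initialised output projection."""
    def __init__(self, d_embed):
        super().__init__()
        self.fc = nn.Linear(d_embed, 4 * d_embed, bias=False)
        self.act = SquaredReLU()
        self.proj = nn.Linear(4 * d_embed, d_embed, bias=False)
        nn.init.zeros_(self.proj.weight)  # zero-init proj: the block contributes nothing at init

    def forward(self, x):
        return self.proj(self.act(self.fc(x)))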


module = modula.atom.Linear(fanout=4, fanin=3)
module @= (modula.atom.Linear(fanout=2, fanin=4), modula.atom.Linear(fanout=2, fanin=4))
module @= modula.atom.Linear(fanout=2, fanin=4)


This example module is wrong, because the hidden dimensions are mismatched. The MLP is constructed from the bottom to the top, meaning the input should be 4-dim.


Also, the TupleModule of the two modules is not concatenated.

gngdb (Author):

Oh yeah, I never ran that example; that network is not runnable.

RuntimeError: mat1 and mat2 shapes cannot be multiplied (5x3 and 4x2)

Just updated to fix it and test that it runs. The tuple isn't concatenated because there is no bond module to concatenate tuples by default; that only happens inside the AttentionQK module. I didn't want to define one just for this example.
