From d405640ca9c2467b446caf3a3c512fba0b71d502 Mon Sep 17 00:00:00 2001 From: Gavia Gray Date: Tue, 29 Jul 2025 19:34:54 -0400 Subject: [PATCH 1/2] add utility to produce a PyTorch module from a modula composite Other stuff: - make rotary match modded-nanogpt - optional automatically substitute in `scaled_dot_product_attention` - tests for all the torch modules - tests for the model conversion --- modula/abstract.py | 134 +++++++++++ modula/atom.py | 71 +++++- modula/bond.py | 3 +- modula/test_torch_modules.py | 449 +++++++++++++++++++++++++++++++++++ modula/to_pytorch.py | 375 +++++++++++++++++++++++++++++ modula/torch_modules.py | 330 +++++++++++++++++++++++++ 6 files changed, 1360 insertions(+), 2 deletions(-) create mode 100644 modula/test_torch_modules.py create mode 100644 modula/to_pytorch.py create mode 100644 modula/torch_modules.py diff --git a/modula/abstract.py b/modula/abstract.py index ec9ca01..004eaec 100644 --- a/modula/abstract.py +++ b/modula/abstract.py @@ -204,3 +204,137 @@ def __init__(self, scalar): def forward(self, x, w): return x * self.sensitivity + +def get_leaf_modules(module): + """ + Walk through a module tree and return the leaf modules (Atom or Bond instances) + in the same order as the corresponding weights would be in the list returned by initialize(). + + Args: + module: A Module instance (typically CompositeModule at top level) + + Returns: + List of leaf modules (Atom or Bond instances) + """ + # Base case: if this is a leaf module (Atom or Bond) + if isinstance(module, (Atom, Bond)): + return [module] + + # If this is a CompositeModule + elif isinstance(module, CompositeModule): + m0, m1 = module.children + # Order matches initialize(): m0 weights first, then m1 weights + return get_leaf_modules(m0) + get_leaf_modules(m1) + + # If this is a TupleModule + elif isinstance(module, TupleModule): + leaf_modules = [] + # Order matches initialize(): iterate through children in order + for child in module.children: + leaf_modules.extend(get_leaf_modules(child)) + return leaf_modules + + # For any other Module type, assume no children or handle as needed + else: + return [] + +def get_leaf_target_norms(module, target_norm=1.0): + """ + Walk through a module tree the same way dualize() does and compute the target norm + that would be passed to each leaf module's dualize method, then store it in the module. 
+ + Args: + module: A Module instance + target_norm: The target norm passed to this module's dualize method + + Returns: + List of target norms for leaf modules, in the same order as get_leaf_modules() + """ + # Base case: if this is a leaf module (Atom or Bond) + if isinstance(module, (Atom, Bond)): + return [target_norm] + + # If this is a CompositeModule + elif isinstance(module, CompositeModule): + if module.mass > 0: + m0, m1 = module.children + # Same logic as in CompositeModule.dualize() + m0.target_norm = target_norm * m0.mass / module.mass / m1.sensitivity + m1.target_norm = target_norm * m1.mass / module.mass + + # Recursively get target norms for children (order: m0 first, then m1) + return (get_leaf_target_norms(m0, m0.target_norm) + + get_leaf_target_norms(m1, m1.target_norm)) + else: + # When mass is 0, we still need to traverse to get the structure right + m0, m1 = module.children + return (get_leaf_target_norms(m0, 0.0) + + get_leaf_target_norms(m1, 0.0)) + + # If this is a TupleModule + elif isinstance(module, TupleModule): + if module.mass > 0: + target_norms = [] + # Same logic as in TupleModule.dualize() + for child in module.children: + child.target_norm = target_norm * child.mass / module.mass + target_norms.extend(get_leaf_target_norms(child, child.target_norm)) + return target_norms + else: + # When mass is 0, we still need to traverse to get the structure right + target_norms = [] + for child in module.children: + target_norms.extend(get_leaf_target_norms(child, 0.0)) + return target_norms + + # For any other Module type, assume no children + else: + return [] + +def traverse_forward_order(module, func): + """ + Traverse composite modules in the order that forward() will execute them, + applying func to each module. + + Args: + module: The module to traverse + func: Function to apply to each module. Should accept a module as argument. + """ + def _traverse(mod): + if isinstance(mod, CompositeModule): + # For composite modules, traverse m0 first, then m1 (execution order) + m0, m1 = mod.children + _traverse(m0) + _traverse(m1) + elif isinstance(mod, TupleModule): + # For tuple modules, traverse all children (they execute in parallel) + for child in mod.children: + _traverse(child) + + # Apply function to current module after traversing children + func(mod) + + _traverse(module) + +def set_atomic_weights(module, weights): + """ + Set weights as attributes on atomic modules by traversing in forward execution order. Assumes each atomic module has only one weight. 
+ + Args: + module: The root module to traverse + weights: List of weights corresponding to atomic modules + """ + weight_index = [0] # Use list to make it mutable in closure + + def assign_weight(mod): + if isinstance(mod, Atom): + if weight_index[0] < len(weights): + mod.weight = weights[weight_index[0]] + weight_index[0] += 1 + + traverse_forward_order(module, assign_weight) + + # Validate that we used all weights + if weight_index[0] != len(weights): + raise ValueError(f"Number of weights ({len(weights)}) doesn't match number of atomic modules ({weight_index[0]})") + diff --git a/modula/atom.py b/modula/atom.py index c46855f..239d561 100644 --- a/modula/atom.py +++ b/modula/atom.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp -from modula.abstract import Atom +from modula.abstract import Atom, Bond, CompositeModule, TupleModule def orthogonalize(M): # six step Newton-Schulz by @YouJiacheng @@ -92,6 +92,73 @@ def dualize(self, grad_w, target_norm=1.0): if __name__ == "__main__": + from modula.abstract import get_leaf_modules, get_leaf_target_norms + + def test_dualize_consistency(module, grad_w, target_norm=1.0, rtol=1e-6): + """ + Test that get_unnormalized_dual and get_leaf_target_norms produce results + consistent with the actual dualize method. + + Args: + module: A Module instance + grad_w: Weight gradient list + target_norm: Target norm to test with + rtol: Relative tolerance for comparison + + Returns: + bool: True if consistent, False otherwise + """ + # Get results from actual dualize + actual_dual = module.dualize(grad_w, target_norm=target_norm) + + # Get results from our functions + unnormalized_dual = get_unnormalized_dual(module, grad_w) + leaf_modules = get_leaf_modules(module) + target_norms = get_leaf_target_norms(module, target_norm=target_norm) + + # Apply target norms to unnormalized dual + predicted_dual = [] + weight_idx = 0 + + for leaf_module, leaf_target_norm in zip(leaf_modules, target_norms): + if isinstance(leaf_module, (Atom)): # Only atoms have weights + leaf_weights = unnormalized_dual[weight_idx:weight_idx + leaf_module.atoms] + # Apply the target norm + scaled_weights = [w * leaf_target_norm for w in leaf_weights] + predicted_dual.extend(scaled_weights) + weight_idx += leaf_module.atoms + # Bonds have no weights, so nothing to add to predicted_dual + + # Compare actual vs predicted + if len(actual_dual) != len(predicted_dual): + print(f"Length mismatch: actual {len(actual_dual)}, predicted {len(predicted_dual)}") + return False + + for i, (actual, predicted) in enumerate(zip(actual_dual, predicted_dual)): + if not jnp.allclose(actual, predicted, rtol=rtol): + print(f"Mismatch at weight {i}") + print(f"Actual shape: {actual.shape}, Predicted shape: {predicted.shape}") + print(f"Max difference: {jnp.max(jnp.abs(actual - predicted))}") + return False + + print("✓ Dualize consistency test passed!") + return True + + # Example usage: + def test_example(): + """Example test with a simple module""" + # Create a simple module + linear = Linear(fanout=4, fanin=3) + linear @= Linear(fanout=4, fanin=4) # Add another linear layer + linear @= Linear(fanout=2, fanin=4) # Add another linear layer + + # Initialize weights and create some gradient + key = jax.random.PRNGKey(42) + weights = linear.initialize(key) + grad_w = [jax.random.normal(key, shape=w.shape) for w in weights] + + # Test consistency + return test_dualize_consistency(linear, grad_w, target_norm=2.5) key = jax.random.PRNGKey(0) @@ -116,3 +183,5 @@ def dualize(self, grad_w, target_norm=1.0): error_O 
= jnp.linalg.norm(O - U @ Vh) / jnp.linalg.norm(U @ Vh) print(f"relative error in M's SVD: {error_M}") print(f"relative error in O: {error_O}") + + test_example() diff --git a/modula/bond.py b/modula/bond.py index 7bdbf00..7e7561a 100644 --- a/modula/bond.py +++ b/modula/bond.py @@ -105,6 +105,7 @@ def __init__(self, d_head, base=10000): self.sensitivity = 1 # rope is an orthogonal transformation self.rope_dim = d_head // 2 + self.base = base self.inverse_frequencies = 1/base**(jnp.arange(self.rope_dim) / self.rope_dim) self.seq_len_cached = None self.sin_cached = None @@ -126,7 +127,7 @@ def rotate(self, x): x1 = x[..., self.rope_dim:] # shape [batch, n_heads, seq_len, rope_dim] x2 = x[..., :self.rope_dim] # shape [batch, n_heads, seq_len, rope_dim] - cos, sin = self.get_cached(seq_len) + sin, cos = self.get_cached(seq_len) y1 = cos * x1 + sin * x2 y2 = -sin * x1 + cos * x2 diff --git a/modula/test_torch_modules.py b/modula/test_torch_modules.py new file mode 100644 index 0000000..9a530f2 --- /dev/null +++ b/modula/test_torch_modules.py @@ -0,0 +1,449 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import jax +import jax.numpy as jnp + +import modula.atom +import modula.bond + +from torch_modules import (orthogonalize, Linear, Embed, ReLU, GeLU, SplitIntoHeads, MergeHeads, + AttentionQK, CausalMask, Softmax, ApplyAttentionScores, + Rope) + +def test_orthogonalize(): + """Test PyTorch implementation against JAX implementation""" + print("Testing PyTorch orthogonalize against JAX modula.atom.orthogonalize...") + + # Test cases with different shapes + test_cases = [ + (4, 4), # Square + (6, 3), # Tall + (3, 6), # Wide + (5, 5), # Another square + (8, 2), # Very tall + (2, 8), # Very wide + ] + + for i, (rows, cols) in enumerate(test_cases): + print(f"\nTest {i+1}: Matrix shape ({rows}, {cols})") + + # Create random matrix + np_matrix = np.random.randn(rows, cols).astype(np.float32) + + # Convert to respective frameworks + torch_matrix = torch.from_numpy(np_matrix) + jax_matrix = jnp.array(np_matrix) + + # Apply orthogonalization + torch_result = orthogonalize(torch_matrix) + jax_result = modula.atom.orthogonalize(jax_matrix) + + # Convert back to numpy for comparison + torch_result_np = torch_result.detach().numpy() + jax_result_np = np.array(jax_result) + + # Check if results are close + max_diff = np.max(np.abs(torch_result_np - jax_result_np)) + relative_error = max_diff / (np.max(np.abs(jax_result_np)) + 1e-8) + + print(f" Max absolute difference: {max_diff:.2e}") + print(f" Relative error: {relative_error:.2e}") + print(f" Results match (abs diff < 1e-6): {max_diff < 1e-6}") + + # Also check orthogonality of PyTorch result + if torch_result.shape[1] <= torch_result.shape[0]: + # Tall or square: check Q^T @ Q = I + product = torch_result.T @ torch_result + identity = torch.eye(torch_result.shape[1]) + ortho_error = torch.max(torch.abs(product - identity)).item() + else: + # Wide: check Q @ Q^T = I + product = torch_result @ torch_result.T + identity = torch.eye(torch_result.shape[0]) + ortho_error = torch.max(torch.abs(product - identity)).item() + + print(f" PyTorch result orthogonality error: {ortho_error:.2e}") + print(f" Is orthogonal (error < 1e-5): {ortho_error < 1e-5}") + + # Test with different dtypes + print("\nTest dtype consistency:") + np_matrix = np.random.randn(4, 3).astype(np.float64) + + # torch_matrix_f32 = torch.from_numpy(np_matrix.astype(np.float32)) + torch_matrix_f64 = torch.from_numpy(np_matrix.astype(np.float64)) + 
jax_matrix = jnp.array(np_matrix) + + # torch_result_f32 = orthogonalize(torch_matrix_f32) + torch_result_f64 = orthogonalize(torch_matrix_f64) + jax_result = modula.atom.orthogonalize(jax_matrix) + + # Compare f64 results (should be most accurate) + diff_f64 = np.max(np.abs(torch_result_f64.numpy() - np.array(jax_result))) + print(f" Float64 max difference: {diff_f64:.2e}") + print(f" Float64 results match: {diff_f64 < 1e-10}") + + # Test gradient compatibility (PyTorch-specific feature) + print("\nTest gradient compatibility:") + torch_matrix = torch.randn(4, 3, requires_grad=True) + torch_result = orthogonalize(torch_matrix) + loss = torch.sum(torch_result ** 2) + loss.backward() + + has_grad = torch_matrix.grad is not None + grad_finite = torch.all(torch.isfinite(torch_matrix.grad)) if has_grad else False + print(f" Gradients computed: {has_grad}") + print(f" Gradients finite: {grad_finite}") + + print("\nAll tests completed!") + +def test_linear_modules(): + """Test that JAX and PyTorch Linear modules produce identical results""" + # Set parameters + fanin, fanout = 8, 4 + batch_size = 2 + + # Create JAX module and initialize weights + jax_linear = modula.atom.Linear(fanout, fanin) # JAX module + key = jax.random.PRNGKey(42) + jax_weights = jax_linear.initialize(key) + + # Create PyTorch module and load the same weights + w = torch.tensor(jax_weights[0]) # Convert JAX weights to PyTorch tensor + torch_linear = Linear.from_modula(jax_linear, w) + + # Create identical input + np.random.seed(42) + input_np = np.random.randn(batch_size, fanin).astype(np.float32) + + jax_input = jnp.array(input_np) + torch_input = torch.from_numpy(input_np) + + # Forward pass + jax_output = jax_linear.forward(jax_input, jax_weights) + torch_output = torch_linear(torch_input) + + # Compare outputs + torch_output_np = torch_output.detach().numpy() + jax_output_np = np.array(jax_output) + + print(f"JAX output shape: {jax_output_np.shape}") + print(f"PyTorch output shape: {torch_output_np.shape}") + print(f"Max absolute difference: {np.max(np.abs(jax_output_np - torch_output_np))}") + + np.testing.assert_allclose(jax_output_np, torch_output_np, rtol=1e-6, atol=1e-6) + print("✓ Linear modules produce identical results") + +def test_embed_modules(): + """Test that JAX and PyTorch Embed modules produce identical results""" + # Set parameters + num_embed, d_embed = 10, 6 + batch_size = 3 + seq_len = 4 + + # Create JAX module and initialize weights + jax_embed = modula.atom.Embed(d_embed, num_embed) # JAX module + key = jax.random.PRNGKey(123) + jax_weights = jax_embed.initialize(key) + + # Create PyTorch module and load the same weights + w = torch.tensor(jax_weights[0]) # Convert JAX weights to PyTorch tensor + torch_embed = Embed.from_modula(jax_embed, w) + + # Create identical input (indices) + np.random.seed(123) + input_indices = np.random.randint(0, num_embed, size=(batch_size, seq_len)) + + jax_input = jnp.array(input_indices) + torch_input = torch.from_numpy(input_indices) + + # Forward pass + jax_output = jax_embed.forward(jax_input, jax_weights) + torch_output = torch_embed(torch_input) + + # Compare outputs + torch_output_np = torch_output.detach().numpy() + jax_output_np = np.array(jax_output) + + print(f"JAX output shape: {jax_output_np.shape}") + print(f"PyTorch output shape: {torch_output_np.shape}") + print(f"Max absolute difference: {np.max(np.abs(jax_output_np - torch_output_np))}") + + np.testing.assert_allclose(jax_output_np, torch_output_np, rtol=1e-6, atol=1e-6) + print("✓ Embed modules 
produce identical results") + +def test_activations(): + # Test data + x = [-2.0, -1.0, 0.0, 1.0, 2.0] + + # PyTorch + torch_x = torch.tensor(x) + torch_relu = ReLU()(torch_x).numpy() + torch_gelu = GeLU()(torch_x).numpy() + + # JAX (assuming modula.bond.ReLU and modula.bond.GeLU are imported) + jax_x = jnp.array(x) + jax_relu = np.array(modula.bond.ReLU().forward(jax_x, None)) + jax_gelu = np.array(modula.bond.GeLU().forward(jax_x, None)) + + # Compare + np.testing.assert_allclose(torch_relu, jax_relu, rtol=1e-6) + np.testing.assert_allclose(torch_gelu, jax_gelu, rtol=1e-4) + + print("✓ ReLU matches PyTorch") + print("✓ GELU matches PyTorch") + +def test_split_and_merge_heads(): + # Test parameters + batch_size = 2 + sequence_length = 8 + embed_dim = 64 + num_heads = 8 + + # Create test input (same for both versions) + np.random.seed(42) + input_data = np.random.randn(batch_size, sequence_length, embed_dim).astype(np.float32) + + print("Testing SplitIntoHeads...") + + # JAX version - Split + jax_split = modula.bond.SplitIntoHeads(num_heads) + jax_input = jnp.array(input_data) + jax_split_output = jax_split.forward(jax_input, None) + + # PyTorch version - Split + torch_split = SplitIntoHeads(num_heads) + torch_input = torch.from_numpy(input_data) + torch_split_output = torch_split.forward(torch_input) + + # Compare split outputs + jax_split_np = np.array(jax_split_output) + torch_split_np = torch_split_output.detach().numpy() + + print(f"Split - JAX output shape: {jax_split_np.shape}") + print(f"Split - PyTorch output shape: {torch_split_np.shape}") + + split_close = np.allclose(jax_split_np, torch_split_np, atol=1e-6) + print(f"Split outputs are numerically close: {split_close}") + + expected_split_shape = (batch_size, num_heads, sequence_length, embed_dim // num_heads) + assert jax_split_np.shape == expected_split_shape, "JAX split shape mismatch" + assert torch_split_np.shape == expected_split_shape, "PyTorch split shape mismatch" + assert split_close, "Split outputs are not numerically equivalent" + + print("✅ SplitIntoHeads tests passed!") + + print("\nTesting MergeHeads...") + + # JAX version - Merge + jax_merge = modula.bond.MergeHeads() + jax_merge_output = jax_merge.forward(jax_split_output, None) + + # PyTorch version - Merge + torch_merge = MergeHeads() + torch_merge_output = torch_merge.forward(torch_split_output) + + # Compare merge outputs + jax_merge_np = np.array(jax_merge_output) + torch_merge_np = torch_merge_output.detach().numpy() + + print(f"Merge - JAX output shape: {jax_merge_np.shape}") + print(f"Merge - PyTorch output shape: {torch_merge_np.shape}") + + merge_close = np.allclose(jax_merge_np, torch_merge_np, atol=1e-6) + print(f"Merge outputs are numerically close: {merge_close}") + + expected_merge_shape = (batch_size, sequence_length, embed_dim) + assert jax_merge_np.shape == expected_merge_shape, f"JAX merge shape mismatch" + assert torch_merge_np.shape == expected_merge_shape, f"PyTorch merge shape mismatch" + assert merge_close, "Merge outputs are not numerically equivalent" + + # Test round-trip: original -> split -> merge should equal original + original_recovered_jax = np.allclose(np.array(jax_input), jax_merge_np, atol=1e-6) + original_recovered_torch = np.allclose(input_data, torch_merge_np, atol=1e-6) + + print(f"JAX round-trip recovery: {original_recovered_jax}") + print(f"PyTorch round-trip recovery: {original_recovered_torch}") + + assert original_recovered_jax, "JAX round-trip failed" + assert original_recovered_torch, "PyTorch round-trip failed" + + 
print("✅ MergeHeads tests passed!") + print("✅ All tests passed!") + +def test_attention_components(): + # Test parameters + batch_size = 2 + num_heads = 4 + seq_len = 6 + d_query = 8 + scale = 2.0 + + # Create test inputs + np.random.seed(42) + q_data = np.random.randn(batch_size, num_heads, seq_len, d_query).astype(np.float32) + k_data = np.random.randn(batch_size, num_heads, seq_len, d_query).astype(np.float32) + v_data = np.random.randn(batch_size, num_heads, seq_len, d_query).astype(np.float32) + + print("Testing AttentionQK...") + + # Test AttentionQK + jax_qk = modula.bond.AttentionQK() + jax_q = jnp.array(q_data) + jax_k = jnp.array(k_data) + jax_scores = jax_qk.forward((jax_q, jax_k), None) + + torch_qk = AttentionQK() + torch_q = torch.from_numpy(q_data) + torch_k = torch.from_numpy(k_data) + torch_scores = torch_qk.forward((torch_q, torch_k)) + + jax_scores_np = np.array(jax_scores) + torch_scores_np = torch_scores.detach().numpy() + + print(f"QK - JAX shape: {jax_scores_np.shape}, PyTorch shape: {torch_scores_np.shape}") + qk_close = np.allclose(jax_scores_np, torch_scores_np, atol=1e-6) + print(f"QK outputs close: {qk_close}") + assert qk_close, "AttentionQK outputs don't match" + + print("✅ AttentionQK tests passed!") + + print("\nTesting CausalMask...") + + # Test CausalMask + jax_mask = modula.bond.CausalMask() + jax_masked = jax_mask.forward(jax_scores, None) + + torch_mask = CausalMask() + torch_masked = torch_mask.forward(torch_scores) + + jax_masked_np = np.array(jax_masked) + torch_masked_np = torch_masked.detach().numpy() + + print(f"Mask - JAX shape: {jax_masked_np.shape}, PyTorch shape: {torch_masked_np.shape}") + + # For masked values, check that -inf values are in the same positions + jax_is_neginf = jax_masked_np == -np.inf + torch_is_neginf = torch_masked_np == -np.inf + mask_positions_match = np.array_equal(jax_is_neginf, torch_is_neginf) + + # For non-masked values, check they're numerically close + non_masked_close = np.allclose(jax_masked_np[~jax_is_neginf], torch_masked_np[~torch_is_neginf], atol=1e-6) + + print(f"Mask positions match: {mask_positions_match}") + print(f"Non-masked values close: {non_masked_close}") + assert mask_positions_match and non_masked_close, "CausalMask outputs don't match" + + print("✅ CausalMask tests passed!") + + print("\nTesting Softmax...") + + # Test Softmax + jax_softmax = modula.bond.Softmax(scale) + jax_soft_out = jax_softmax.forward(jax_masked, None) + + torch_softmax = Softmax(scale) + torch_soft_out = torch_softmax.forward(torch_masked) + + jax_soft_np = np.array(jax_soft_out) + torch_soft_np = torch_soft_out.detach().numpy() + + print(f"Softmax - JAX shape: {jax_soft_np.shape}, PyTorch shape: {torch_soft_np.shape}") + softmax_close = np.allclose(jax_soft_np, torch_soft_np, atol=1e-6) + print(f"Softmax outputs close: {softmax_close}") + assert softmax_close, "Softmax outputs don't match" + + print("✅ Softmax tests passed!") + + print("\nTesting ApplyAttentionScores...") + + # Test ApplyAttentionScores + jax_apply = modula.bond.ApplyAttentionScores() + jax_v = jnp.array(v_data) + jax_final = jax_apply.forward((jax_v, jax_soft_out), None) + + torch_apply = ApplyAttentionScores() + torch_v = torch.from_numpy(v_data) + torch_final = torch_apply.forward((torch_v, torch_soft_out)) + + jax_final_np = np.array(jax_final) + torch_final_np = torch_final.detach().numpy() + + print(f"Apply - JAX shape: {jax_final_np.shape}, PyTorch shape: {torch_final_np.shape}") + apply_close = np.allclose(jax_final_np, torch_final_np, atol=1e-6) 
+ print(f"Apply outputs close: {apply_close}") + assert apply_close, "ApplyAttentionScores outputs don't match" + + print("✅ ApplyAttentionScores tests passed!") + print("✅ All attention component tests passed!") + +def test_rope_equivalence(): + # Test parameters + batch_size = 2 + n_heads = 8 + seq_len = 16 + d_head = 64 + base = 10000 + + # Create test input + np.random.seed(42) + q_np = np.random.randn(batch_size, n_heads, seq_len, d_head).astype(np.float32) + k_np = np.random.randn(batch_size, n_heads, seq_len, d_head).astype(np.float32) + + # Convert to respective frameworks + q_torch = torch.from_numpy(q_np) + k_torch = torch.from_numpy(k_np) + q_jax = jnp.array(q_np) + k_jax = jnp.array(k_np) + + # Initialize models + rope_jax = modula.bond.Rope(d_head, base=base) + rope_torch = Rope(d_head, base=base) + + # Forward pass + with torch.no_grad(): + q_out_torch, k_out_torch = rope_torch((q_torch, k_torch)) + + q_out_jax, k_out_jax = rope_jax.forward((q_jax, k_jax), None) # w=None since it's not used + + sin_jax, cos_jax = rope_jax.get_cached(seq_len) + + # Extract cached values + sin_torch = rope_torch.sin_cached.numpy() + cos_torch = rope_torch.cos_cached.numpy() + + # Check if base values match + sin_match = np.allclose(sin_torch, sin_jax, rtol=1e-5, atol=1e-6) + cos_match = np.allclose(cos_torch, cos_jax, rtol=1e-5, atol=1e-6) + + print(f"Sin values match: {sin_match}") + print(f"Cos values match: {cos_match}") + + # Convert to numpy for comparison + q_out_torch_np = q_out_torch.numpy() + k_out_torch_np = k_out_torch.numpy() + q_out_jax_np = np.array(q_out_jax) + k_out_jax_np = np.array(k_out_jax) + + # Check equivalence + q_close = np.allclose(q_out_torch_np, q_out_jax_np, rtol=1e-5, atol=1e-6) + k_close = np.allclose(k_out_torch_np, k_out_jax_np, rtol=1e-5, atol=1e-6) + + print(f"Q outputs match: {q_close}") + print(f"K outputs match: {k_close}") + print(f"Max Q difference: {np.max(np.abs(q_out_torch_np - q_out_jax_np))}") + print(f"Max K difference: {np.max(np.abs(k_out_torch_np - k_out_jax_np))}") + + assert q_close and k_close, "Rope implementations do not match!" + print("✅ All tests passed! Both implementations produce identical outputs.") + + + +test_orthogonalize() +test_linear_modules() +test_embed_modules() +test_activations() +test_split_and_merge_heads() +test_attention_components() +test_rope_equivalence() diff --git a/modula/to_pytorch.py b/modula/to_pytorch.py new file mode 100644 index 0000000..005b881 --- /dev/null +++ b/modula/to_pytorch.py @@ -0,0 +1,375 @@ +import jax +import torch +import torch.nn as nn + +import modula.abstract +import modula.atom +import modula.bond + +import modula.torch_modules as torch_modules + +substitute = ( + modula.atom.Linear, torch_modules.Linear, + modula.atom.Embed, torch_modules.Embed, + modula.bond.ReLU, nn.ReLU, + modula.bond.GeLU, torch_modules.GeLU, + modula.bond.SplitIntoHeads, torch_modules.SplitIntoHeads, + modula.bond.MergeHeads, torch_modules.MergeHeads, + modula.bond.AttentionQK, torch_modules.AttentionQK, + modula.bond.CausalMask, torch_modules.CausalMask, + modula.bond.Softmax, torch_modules.Softmax, + modula.bond.ApplyAttentionScores, torch_modules.ApplyAttentionScores, + modula.bond.Rope, torch_modules.Rope, + modula.abstract.Mul, torch_modules.Mul, + modula.abstract.Add, torch_modules.Add, + modula.abstract.Identity, torch_modules.Identity, +) + +def get_structure_signature(module): + """ + Get a hashable signature representing the module's structure. 
+
+    Returns:
+        tuple: A nested tuple representing the structure
+    """
+    if isinstance(module, modula.abstract.Atom):
+        return ('Atom', type(module).__name__)
+
+    if isinstance(module, modula.abstract.Bond):
+        if isinstance(module, modula.abstract.Mul):
+            # Include scalar for Mul since it affects structure
+            return ('Bond', type(module).__name__, module.sensitivity)
+        return ('Bond', type(module).__name__)
+
+    if isinstance(module, modula.abstract.CompositeModule):
+        m0, m1 = module.children
+        return ('Composite', get_structure_signature(m0), get_structure_signature(m1))
+
+    if isinstance(module, modula.abstract.TupleModule):
+        child_sigs = tuple(get_structure_signature(child) for child in module.children)
+        return ('Tuple', child_sigs)
+
+    return ('Unknown', type(module).__name__)
+
+# Usage example:
+def modules_match_by_signature(module1, module2):
+    """Fast structural comparison using signatures."""
+    return get_structure_signature(module1) == get_structure_signature(module2)
+
+def convert_modula_to_torch(modula_module, substitute_map=substitute):
+    """
+    Convert a modula module to its PyTorch equivalent using the substitution map.
+
+    Args:
+        modula_module: The modula module to convert
+        substitute_map: Tuple of (modula_class, torch_class) pairs
+
+    Returns:
+        Corresponding PyTorch module
+
+    Raises:
+        ValueError: If no corresponding PyTorch module is found
+    """
+    if isinstance(modula_module, nn.Module):
+        # If it's already a PyTorch module, return it as is
+        return modula_module
+    # Create a dictionary from the tuple for easier lookup
+    substitution_dict = {}
+    for i in range(0, len(substitute_map), 2):
+        modula_class = substitute_map[i]
+        torch_class = substitute_map[i + 1]
+        substitution_dict[modula_class] = torch_class
+
+    # Find the corresponding torch class
+    modula_type = type(modula_module)
+    if modula_type not in substitution_dict:
+        raise ValueError(f"No PyTorch equivalent found for {modula_type}")
+
+    torch_class = substitution_dict[modula_type]
+
+    # Convert using the from_modula static method
+    if hasattr(torch_class, 'from_modula'):
+        # try to access a `weight` attribute if it exists
+        if hasattr(modula_module, 'weight'):
+            weight = torch.tensor(modula_module.weight, dtype=torch.float32)
+            return torch_class.from_modula(modula_module, weight)
+        else:
+            return torch_class.from_modula(modula_module)
+    else: # If no from_modula method, just instantiate the class
+        return torch_class() # it won't have arguments if it doesn't have from_modula
+
+# Reference Attention module for detection
+def Attention(num_heads, d_embed, d_query, d_value, attention_scale):
+    """Multi-head attention"""
+
+    # For keys, queries, and values we add a heads dimension. For the out projection, we remove heads.
+    # Remember modules compose right-to-left, and the order is modula.atom.Linear(d_out, d_in)! And @ means compose.
+    Q = modula.bond.SplitIntoHeads(num_heads) @ modula.atom.Linear(num_heads * d_query, d_embed)
+    K = modula.bond.SplitIntoHeads(num_heads) @ modula.atom.Linear(num_heads * d_query, d_embed)
+    V = modula.bond.SplitIntoHeads(num_heads) @ modula.atom.Linear(num_heads * d_value, d_embed)
+    W = modula.atom.Linear(d_embed, num_heads * d_value) @ modula.bond.MergeHeads()
+
+    # Read right-to-left: rotate (Q, K) with RoPE, apply Q @ K.T, mask, softmax (with a scale we can choose).
+ AttentionScores = modula.bond.Softmax(attention_scale) @ modula.bond.CausalMask() @ modula.bond.AttentionQK() @ modula.bond.Rope(d_query) @ (Q, K) + + # Read right-to-left: apply attention scores, multiply by 1/3 to fix the sensitivity to 1, project back to d_embed. + return W @ (1/3 * modula.bond.ApplyAttentionScores()) @ (V, AttentionScores) + +def extract_attention_parameters(module): + """ + Extract attention parameters from an Attention module structure. + + Returns: + dict: Dictionary with keys 'num_heads', 'd_embed', 'd_query', 'd_value', 'attention_scale' + or None if the structure doesn't match expected Attention pattern + """ + # First verify this looks like an attention module by checking high-level structure + if not isinstance(module, modula.abstract.CompositeModule): + return None + + # Walk through and collect all relevant modules + split_heads_modules = [] + linear_modules = [] + softmax_modules = [] + + def collect_modules(mod): + # Import the actual module classes - adjust these imports based on your actual module structure + if hasattr(mod, '__class__'): + class_name = mod.__class__.__name__ + if class_name == 'SplitIntoHeads': + split_heads_modules.append(mod) + elif class_name == 'Linear': + linear_modules.append(mod) + elif class_name == 'Softmax': + softmax_modules.append(mod) + + modula.abstract.traverse_forward_order(module, collect_modules) + + # Verify we have the expected number of modules + if len(split_heads_modules) < 1 or len(linear_modules) < 4 or len(softmax_modules) < 1: + return None + + try: + # Extract num_heads from first SplitIntoHeads + num_heads = split_heads_modules[0].num_heads + + # Extract attention_scale from Softmax + attention_scale = softmax_modules[0].sensitivity # Assuming it's stored as self.scale + + # For linear modules, based on the structure: + # - First Linear encountered should be V (due to tuple ordering): fanout = num_heads * d_value, fanin = d_embed + # - We need to find Q, K, V, and W linear modules + + # The V linear module (first one encountered) + v_linear = linear_modules[0] + d_embed = v_linear.fanin + num_heads_times_d_value = v_linear.fanout + d_value = num_heads_times_d_value // num_heads + + # Find Q or K linear module to get d_query + # Q and K should have the same fanout = num_heads * d_query + # We can identify them as having fanin = d_embed and fanout != num_heads * d_value + qk_candidates = [lin for lin in linear_modules[1:] + if lin.fanin == d_embed] + + if len(qk_candidates) < 2: # Should have both Q and K + print("Not enough Q/K candidates found.") + return None + + num_heads_times_d_query = qk_candidates[0].fanout + d_query = num_heads_times_d_query // num_heads + + # Verify consistency + if (num_heads_times_d_query != qk_candidates[1].fanout or + d_value * num_heads != num_heads_times_d_value or + d_query * num_heads != num_heads_times_d_query): + print("Inconsistent dimensions found in attention module.") + return None + + return { + 'num_heads': num_heads, + 'd_embed': d_embed, + 'd_query': d_query, + 'd_value': d_value, + 'attention_scale': attention_scale + } + + except (AttributeError, ZeroDivisionError, IndexError): + raise + return None + +def sequentialise(module, substitute_flash=False, verbose=False): + modules = [] + + if substitute_flash: + attention_signature = get_structure_signature(Attention(num_heads=1, d_embed=1, d_query=1, d_value=1, attention_scale=1.0)) + + def _traverse_tuple(mod): + return torch_modules.Parallel([sequentialise(child, substitute_flash=substitute_flash) for child in 
mod.children]) + + def _traverse(mod): + if substitute_flash: + if verbose: + print(f"Traversing module: {hash(get_structure_signature(mod))} - {get_structure_signature(mod)}") + if get_structure_signature(mod) == attention_signature: + # If the module matches the attention signature, convert it to a torch attention module + mod = torch_modules.FlashAttention(**extract_attention_parameters(mod)) + for m in mod: + modules.append(m) + mod = None + if isinstance(mod, modula.abstract.CompositeModule): + # For composite modules, traverse m0 first, then m1 (execution order) + m0, m1 = mod.children + _traverse(m0) + _traverse(m1) + elif isinstance(mod, modula.abstract.TupleModule): + # For tuple modules, traverse all children (they execute in parallel) + mod = _traverse_tuple(mod) + + # Apply function to current module after traversing children + if not any((isinstance(mod, cls) for cls in [modula.abstract.CompositeModule, modula.abstract.TupleModule])): + if mod is not None: + modules.append(mod) + + _traverse(module) + + # convert modula.atom and modula.bond modules to torch modules + modules = [convert_modula_to_torch(m) for m in modules] + seq_module = nn.Sequential(*modules) + return seq_module + +def flash_sequentialise(module): + state_dict = sequentialise(module, substitute_flash=False).state_dict() + flash_module = sequentialise(module, substitute_flash=True, verbose=False) + flash_module.load_state_dict(state_dict) + return flash_module + +if __name__ == "__main__": + import numpy as np + key = jax.random.PRNGKey(0) + torch.manual_seed(0) # For reproducibility in PyTorch + + # Example usage + + module = modula.atom.Linear(fanout=4, fanin=3) + module @= (modula.atom.Linear(fanout=2, fanin=4), modula.atom.Linear(fanout=2, fanin=4)) + module @= modula.atom.Linear(fanout=2, fanin=4) + + print(sequentialise(module)) + + attention = Attention(num_heads=2, d_embed=8, d_query=4, d_value=4, attention_scale=1.0) + weights = attention.initialize(key) + + # equip attention modules with target norms and weights + _ = modula.abstract.get_leaf_target_norms(attention, target_norm=1.0) + modula.abstract.set_atomic_weights(attention, weights) + torch_attention = sequentialise(attention) + print(torch_attention) + + # Check they are equivalent + x = torch.randn(2, 10, 8) # batch size 2, sequence length 10, embedding size 8 + modula_out = attention(x.numpy(), weights) + torch_out = torch_attention(x) + rtol = 1e-3 # relative tolerance for comparison + np.testing.assert_allclose(modula_out, torch_out.detach().numpy(), rtol=rtol) + print(f"Attention conversion successful and outputs match to within {rtol} relative tolerance.") + + + print([n for n, v in torch_attention.state_dict().items()]) # Print the names of the parameters in the torch attention module + + + # flash = torch_modules.FlashAttention(num_heads=2, d_embed=8, d_query=4, d_value=4, attention_scale=1.0) + flash = sequentialise(attention, substitute_flash=True) + print([n for n, v in flash.state_dict().items()]) # Print the names of the parameters in the flash attention module + + flash.load_state_dict(torch_attention.state_dict()) # Load weights from the previous attention module + print(flash) + + # test the Q, K and V modules are OK + v, _v = torch_attention[0][0][0], flash[0][0][0] + q, _q = torch_attention[0][1][0][0][0], flash[0][1][0][0][0] + k, _k = torch_attention[0][1][0][1][0], flash[0][1][0][1][0] + proj, _proj = torch_attention[4], flash[4] + rtol = 1e-3 # relative tolerance for comparison + 
np.testing.assert_allclose(v.weight.detach().numpy(), _v.weight.detach().numpy(), rtol=rtol) + np.testing.assert_allclose(q.weight.detach().numpy(), _q.weight.detach().numpy(), rtol=rtol) + np.testing.assert_allclose(k.weight.detach().numpy(), _k.weight.detach().numpy(), rtol=rtol) + np.testing.assert_allclose(proj.weight.detach().numpy(), flash[4].weight.detach().numpy(), rtol=rtol) + + # Check they are equivalent + torch_out_flash = flash(x) + np.testing.assert_allclose(torch_out.detach().numpy(), torch_out_flash.detach().numpy(), rtol=rtol) + print(f"Flash attention conversion successful and outputs match to within {rtol} relative tolerance.") + + # Do the same for a complete model + def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0): + # Set embed to have mass 1. This controls the proportion of feature learning that it contributes to the whole network. + embed = modula.atom.Embed(d_embed, vocab_size) + embed.tare() + + # Let's create attention and MLP layers. + att = Attention(num_heads, d_embed, d_query, d_value, attention_scale) + print(f"att hash = {hash(get_structure_signature(att))}") + mlp = modula.atom.Linear(d_embed, 4*d_embed) @ modula.bond.GeLU() @ modula.atom.Linear(4*d_embed, d_embed) + + # For our residual connections, L = 2*num_blocks because each block has two residual connections. + att_block = (1-1/(2*num_blocks)) * modula.abstract.Identity() + 1/(2*num_blocks) * att + mlp_block = (1-1/(2*num_blocks)) * modula.abstract.Identity() + 1/(2*num_blocks) * mlp + + # We can use powers of a module to compose it with itself many times! + blocks = (mlp_block @ att_block) ** num_blocks + + # Set all transformer blocks to have mass 5 (by default). + # So 5/7 of the change in the network output is due to the blocks, + # and 2/7 of the change in output is due to the embedding and out projection. 
+ blocks.tare(absolute=blocks_mass) + + out = final_scale * modula.atom.Linear(vocab_size, d_embed) + + return out @ blocks @ embed + + vocab_size = 65 + num_heads = 4 + d_embed = 128 + d_query = 32 + d_value = 32 + num_blocks = 4 + attention_scale = 1 + final_scale = 1 + + model = GPT( + vocab_size=vocab_size, + num_heads=num_heads, + d_embed=d_embed, + d_query=d_query, + d_value=d_value, + num_blocks=num_blocks, + attention_scale=attention_scale, + final_scale=final_scale, + ) + + weights = model.initialize(key) + _ = modula.abstract.get_leaf_target_norms(model, target_norm=1.0) + modula.abstract.set_atomic_weights(model, weights) + + torch_model = sequentialise(model) + print(torch_model) + + # Check they are equivalent + x = torch.randint(0, vocab_size, (2, 10)) # batch size 2, sequence length 10 + modula_out = model(x.numpy(), weights) + torch_out = torch_model(x) + + rtol = 1e-3 # relative tolerance for comparison + np.testing.assert_allclose(modula_out, torch_out.detach().numpy(), rtol=rtol) + print(f"GPT conversion successful and outputs match to within {rtol} relative tolerance.") + + # Test we can convert the model to FlashAttention + flash_model = flash_sequentialise(model) + print(flash_model) + + # Check they are equivalent + torch_out_flash = flash_model(x) + np.testing.assert_allclose(torch_out.detach().numpy(), torch_out_flash.detach().numpy(), rtol=rtol) + print(f"Flash GPT conversion successful and outputs match to within {rtol} relative tolerance.") + diff --git a/modula/torch_modules.py b/modula/torch_modules.py new file mode 100644 index 0000000..4c6d206 --- /dev/null +++ b/modula/torch_modules.py @@ -0,0 +1,330 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def orthogonalize(M): + # six step Newton-Schulz by @YouJiacheng + # coefficients from: https://twitter.com/YouJiacheng/status/1893704552689303901 + # found by optimization: https://gist.github.com/YouJiacheng/393c90cbdc23b09d5688815ba382288b/5bff1f7781cf7d062a155eecd2f13075756482ae + # the idea of stability loss was from @leloykun + + abc_list = [ + (3955/1024, -8306/1024, 5008/1024), + (3735/1024, -6681/1024, 3463/1024), + (3799/1024, -6499/1024, 3211/1024), + (4019/1024, -6385/1024, 2906/1024), + (2677/1024, -3029/1024, 1162/1024), + (2172/1024, -1833/1024, 682/1024) + ] + + transpose = M.shape[1] > M.shape[0] + if transpose: + M = M.T + M = M / torch.linalg.norm(M) + for a, b, c in abc_list: + A = M.T @ M + I = torch.eye(A.shape[0], device=M.device, dtype=M.dtype) + M = M @ (a * I + b * A + c * A @ A) + if transpose: + M = M.T + return M + + +# Atomic +class Linear(nn.Linear): + def __init__(self, fanin, fanout, target_norm=1.0): + super().__init__(fanin, fanout, bias=False) + self.fanin = fanin + self.fanout = fanout + self.register_buffer('target_norm', torch.tensor(target_norm, dtype=torch.float)) + + # Initialize with orthogonal weights + self._initialize_weights() + + def _initialize_weights(self): + with torch.no_grad(): + weight = torch.randn(self.fanout, self.fanin) + weight = orthogonalize(weight) * torch.sqrt(torch.tensor(self.fanout / self.fanin)) + self.weight.copy_(weight) + + def project_weights(self): + """Project weights to the constraint manifold""" + with torch.no_grad(): + weight = orthogonalize(self.weight) * torch.sqrt(torch.tensor(self.fanout / self.fanin)) + self.weight.copy_(weight) + + def dualize_gradients(self): + """Apply dualization to gradients""" + if self.weight.grad is not None: + with torch.no_grad(): + grad = self.weight.grad + d_weight = 
orthogonalize(grad) * torch.sqrt(torch.tensor(self.fanout / self.fanin)) * self.target_norm + self.weight.grad.copy_(d_weight) + + @staticmethod + def from_modula(m, w=None): + """Convert from modula.atom.Linear""" + with torch.no_grad(): + linear = Linear(m.fanin, m.fanout, getattr(m, 'target_norm', 1.0)) + if w is not None: + linear.weight.copy_(w) + return linear + + def __repr__(self): + return f"Linear(fanin={self.fanin}, fanout={self.fanout}, target_norm={self.target_norm})" + +class Embed(nn.Embedding): + def __init__(self, num_embed, d_embed, target_norm, padding_idx=None): + super().__init__(num_embed, d_embed, padding_idx=padding_idx) + self.num_embed = num_embed + self.d_embed = d_embed + self.sensitivity = 1 + self.register_buffer('target_norm', torch.tensor(target_norm, dtype=torch.float)) + + # Initialize with normalized weights + self._initialize_weights() + + def _initialize_weights(self): + with torch.no_grad(): + weight = torch.randn(self.num_embed, self.d_embed) + weight = F.normalize(weight, p=2, dim=1) * torch.sqrt(torch.tensor(self.d_embed, dtype=torch.float)) + self.weight.copy_(weight) + + def project_weights(self): + """Project weights to the constraint manifold""" + with torch.no_grad(): + weight = F.normalize(self.weight, p=2, dim=1) * torch.sqrt(torch.tensor(self.d_embed, dtype=torch.float)) + self.weight.copy_(weight) + + def dualize_gradients(self): + """Apply dualization to gradients""" + if self.weight.grad is not None: + with torch.no_grad(): + grad = self.weight.grad + # Normalize each embedding vector's gradient + grad_norm = torch.norm(grad, p=2, dim=1, keepdim=True) + # Handle zero gradients to avoid NaN + grad_norm = torch.where(grad_norm == 0, torch.ones_like(grad_norm), grad_norm) + d_weight = grad / grad_norm * torch.sqrt(torch.tensor(self.d_embed, dtype=torch.float)) * self.target_norm + # Handle any remaining NaN values + d_weight = torch.nan_to_num(d_weight) + self.weight.grad.copy_(d_weight) + + @staticmethod + def from_modula(m, w=None): + """Convert from modula.atom.Embed""" + with torch.no_grad(): + embed = Embed(m.num_embed, m.d_embed, getattr(m, 'target_norm', 1.0)) + if w is not None: + embed.weight.copy_(w) + return embed + + def __repr__(self): + return f"Embed(num_embed={self.num_embed}, d_embed={self.d_embed}, target_norm={self.target_norm})" + +# Bond + +class ReLU(nn.ReLU): + pass # No changes needed, inherits from nn.ReLU + +class GeLU(nn.GELU): + def __init__(self): + super().__init__(approximate='tanh') + + def forward(self, x): + return super().forward(x) / 1.1289 # 1.1289 is the max derivative of gelu(x) + + +class SplitIntoHeads(nn.Module): + """Reshapes an input to have heads. + + Input shape: (batch_size, sequence_length, embed_dim) + Output shape: (batch_size, num_heads, sequence_length, head_size) + + Adapted from Karpathy's nanoGPT. 
+ """ + def __init__(self, num_heads): + super().__init__() + self.num_heads = num_heads + + def forward(self, x): + B, T, D = x.shape + return x.reshape(B, T, self.num_heads, D // self.num_heads).transpose(1, 2) + + @staticmethod + def from_modula(m): + """Convert from modula.bond.SplitIntoHeads""" + return SplitIntoHeads(m.num_heads) + + +class MergeHeads(nn.Module): + """Inverse of SplitIntoHeads.""" + def forward(self, x): + B, num_heads, T, head_dim = x.shape + return x.transpose(1, 2).reshape(B, T, num_heads * head_dim) + + +class AttentionQK(nn.Module): + """Computes the query and key matrix multiplication in attention.""" + def __init__(self): + super().__init__() + + def forward(self, x): + q, k = x # both shape [batch, n_heads, seq_len, d_query] + scale = 1 / q.shape[-1] + scores = q @ k.transpose(-2, -1) * scale + return scores # shape [batch, n_heads, seq_len, seq_len] + + +class CausalMask(nn.Module): + """Masks the upper triangular part of the attention scores.""" + def __init__(self): + super().__init__() + + def forward(self, x): + scores = x + seq_len = scores.shape[-1] + mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device)) + return torch.where(mask, scores, torch.tensor(-float('inf'), device=scores.device)) + + +class Softmax(nn.Module): + """Softmax with a sharpness parameter.""" + def __init__(self, scale): + super().__init__() + self.sensitivity = scale + + def forward(self, x): + return F.softmax(self.sensitivity * x, dim=-1) + + @staticmethod + def from_modula(m): + """Convert from modula.bond.Softmax""" + return Softmax(m.sensitivity) + + +class ApplyAttentionScores(nn.Module): + """Computes attention values from the scores.""" + def __init__(self): + super().__init__() + + def forward(self, x): + v, scores = x + return scores @ v + +class ScaledDotProductAttention(nn.Module): + def __init__(self, scale): + super().__init__() + self.sensitivity = scale + + def forward(self, x): + v, (q, k) = x + scale = self.sensitivity / q.shape[-1] + out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale) + return out + +def apply_rotary_emb(x, cos, sin): + assert x.ndim == 4 # multihead attention + d = x.shape[3] // 2 + x1 = x[..., d:] # Second half first + x2 = x[..., :d] # First half second + y1 = cos * x1 + sin * x2 + y2 = -sin * x1 + cos * x2 + return torch.cat([y1, y2], 3) + +class Rope(nn.Module): + def __init__(self, d_head, base=10000): + super().__init__() + self.rope_dim = d_head // 2 + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dim).float() / self.rope_dim)) + self.register_buffer("inv_freq", inv_freq) + self.seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + + def forward(self, x): + q, k = x + seq_len = q.shape[2] # Assuming shape [batch, n_heads, seq_len, d_head] + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=q.device).type_as(self.inv_freq) + freqs = torch.outer(t, self.inv_freq).to(q.device) + self.cos_cached = freqs.cos() + self.sin_cached = freqs.sin() + + # Shape: [1, 1, seq_len, rope_dim] to match [batch, n_heads, seq_len, d_head] + cos = self.cos_cached[None, None, :, :] + sin = self.sin_cached[None, None, :, :] + + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + return q, k + + @staticmethod + def from_modula(m): + """Convert from modula.bond.Rope""" + return Rope(2 * m.rope_dim, base=m.base) + +# Abstract + +class Identity(nn.Module): + def forward(self, x): + return x + +class 
Add(nn.Module):
+    def forward(self, x):
+        a, b = x
+        return a + b
+
+class Mul(nn.Module):
+    def __init__(self, scalar):
+        super().__init__()
+        self.sensitivity = scalar
+
+    def forward(self, x):
+        return x * self.sensitivity
+
+    @staticmethod
+    def from_modula(m):
+        """Convert from modula.abstract.Mul"""
+        return Mul(m.sensitivity)
+
+# Composite
+class Parallel(nn.Module):
+    def __init__(self, *modules):
+        super().__init__()
+        self.module_list = nn.ModuleList(*modules)
+
+    def __getitem__(self, idx):
+        return self.module_list[idx]
+
+    def forward(self, inputs):
+        return [module(inputs) for module in self.module_list]
+
+def FlashAttention(num_heads, d_embed, d_query, d_value, attention_scale):
+    # Note: torch_modules.Linear takes (fanin, fanout), so each projection is written
+    # as Linear(d_in, d_out) to match the shapes produced by sequentialise().
+    return nn.Sequential(
+        Parallel([
+            nn.Sequential( # V
+                Linear(d_embed, num_heads * d_value),
+                SplitIntoHeads(num_heads=num_heads)
+            ),
+            nn.Sequential(
+                Parallel([
+                    nn.Sequential( # Q
+                        Linear(d_embed, num_heads * d_query),
+                        SplitIntoHeads(num_heads=num_heads)
+                    ),
+                    nn.Sequential( # K
+                        Linear(d_embed, num_heads * d_query),
+                        SplitIntoHeads(num_heads=num_heads)
+                    )
+                ]),
+                Rope(d_query),
+            )
+        ]),
+        ScaledDotProductAttention(attention_scale),
+        Mul(1/3),
+        MergeHeads(),
+        Linear(num_heads * d_value, d_embed)
+    )

From 1dbd5a2dbf3fa80474f85d1be1b689af165b252a Mon Sep 17 00:00:00 2001
From: Gavia Gray
Date: Wed, 22 Oct 2025 18:33:02 -0400
Subject: [PATCH 2/2] make first example in to_pytorch actually work

---
 modula/to_pytorch.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/modula/to_pytorch.py b/modula/to_pytorch.py
index 005b881..fc01978 100644
--- a/modula/to_pytorch.py
+++ b/modula/to_pytorch.py
@@ -250,12 +250,14 @@ def flash_sequentialise(module):
     torch.manual_seed(0) # For reproducibility in PyTorch

     # Example usage
-
-    module = modula.atom.Linear(fanout=4, fanin=3)
-    module @= (modula.atom.Linear(fanout=2, fanin=4), modula.atom.Linear(fanout=2, fanin=4))
-    module @= modula.atom.Linear(fanout=2, fanin=4)
+    module = modula.abstract.TupleModule([modula.atom.Linear(fanout=2, fanin=4),
+                                          modula.atom.Linear(fanout=2, fanin=4)])
+    module @= modula.atom.Linear(fanout=4, fanin=3)

     print(sequentialise(module))
+    x = torch.randn(5,3)
+    print(f"input: {x.shape}")
+    print("output:", [y.shape for y in sequentialise(module)(x)])

     attention = Attention(num_heads=2, d_embed=8, d_query=4, d_value=4, attention_scale=1.0)
     weights = attention.initialize(key)
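Not part of the patch: a minimal sketch of how the converted modules might be used end to end, assuming the package layout above (modula.abstract, modula.atom, modula.bond, modula.to_pytorch, modula.torch_modules). It exercises the dualize_gradients() hook that the patch defines on torch_modules.Linear and Embed; the data, loss, and learning rate are illustrative placeholders, and the dualize-then-SGD step is one plausible use of that hook rather than something the patch itself demonstrates.

import jax
import torch
import torch.nn.functional as F

import modula.abstract
import modula.atom
import modula.bond
from modula.to_pytorch import sequentialise

# Build a small modula MLP, initialize it, and attach target norms and weights
# to the atoms so that sequentialise() can copy them into the PyTorch modules.
model = modula.atom.Linear(fanout=2, fanin=8) @ modula.bond.ReLU() @ modula.atom.Linear(fanout=8, fanin=4)
weights = model.initialize(jax.random.PRNGKey(0))
_ = modula.abstract.get_leaf_target_norms(model, target_norm=1.0)
modula.abstract.set_atomic_weights(model, weights)
torch_model = sequentialise(model)  # nn.Sequential of torch_modules.Linear / nn.ReLU

# One illustrative training step: backprop, dualize the gradients in place,
# then take a plain SGD step using the dualized gradients.
x = torch.randn(16, 4)   # placeholder batch (input dim matches fanin=4)
y = torch.randn(16, 2)   # placeholder targets (output dim matches fanout=2)
loss = F.mse_loss(torch_model(x), y)
loss.backward()

lr = 0.1  # placeholder step size
with torch.no_grad():
    for m in torch_model.modules():
        if hasattr(m, "dualize_gradients"):
            m.dualize_gradients()  # orthogonalize/normalize the gradient, scaled by target_norm
    for p in torch_model.parameters():
        if p.grad is not None:
            p.sub_(lr * p.grad)
            p.grad = None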