134 changes: 134 additions & 0 deletions modula/abstract.py
@@ -204,3 +204,137 @@ def __init__(self, scalar):

    def forward(self, x, w):
        return x * self.sensitivity

def get_leaf_modules(module):
    """
    Walk through a module tree and return its leaf modules (Atom or Bond instances),
    ordered consistently with the weight list returned by initialize() (Bond leaves
    contribute no weights to that list).

    Args:
        module: A Module instance (typically a CompositeModule at the top level)

    Returns:
        List of leaf modules (Atom or Bond instances)
    """
    # Base case: if this is a leaf module (Atom or Bond)
    if isinstance(module, (Atom, Bond)):
        return [module]

    # If this is a CompositeModule
    elif isinstance(module, CompositeModule):
        m0, m1 = module.children
        # Order matches initialize(): m0 weights first, then m1 weights
        return get_leaf_modules(m0) + get_leaf_modules(m1)

    # If this is a TupleModule
    elif isinstance(module, TupleModule):
        leaf_modules = []
        # Order matches initialize(): iterate through children in order
        for child in module.children:
            leaf_modules.extend(get_leaf_modules(child))
        return leaf_modules

    # Any other Module type is assumed to contribute no leaves
    else:
        return []

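A minimal usage sketch (illustrative only; shapes are arbitrary and the composite is built the same way as the test code in modula/atom.py below):

```python
from modula.abstract import get_leaf_modules
from modula.atom import Linear

mlp = Linear(fanout=4, fanin=3)
mlp @= Linear(fanout=2, fanin=4)  # composite of two Linear atoms

leaves = get_leaf_modules(mlp)
print([type(leaf).__name__ for leaf in leaves])  # e.g. ['Linear', 'Linear']
```
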
def get_leaf_target_norms(module, target_norm=1.0):
    """
    Walk through a module tree the same way dualize() does, computing the target norm
    that would be passed to each leaf module's dualize method and storing it on each
    child module as .target_norm along the way.

    Args:
        module: A Module instance
        target_norm: The target norm passed to this module's dualize method

    Returns:
        List of target norms for leaf modules, in the same order as get_leaf_modules()
    """
    # Base case: if this is a leaf module (Atom or Bond)
    if isinstance(module, (Atom, Bond)):
        return [target_norm]

    # If this is a CompositeModule
    elif isinstance(module, CompositeModule):
        if module.mass > 0:
            m0, m1 = module.children
            # Same logic as in CompositeModule.dualize()
            m0.target_norm = target_norm * m0.mass / module.mass / m1.sensitivity
            m1.target_norm = target_norm * m1.mass / module.mass

            # Recursively get target norms for children (order: m0 first, then m1)
            return (get_leaf_target_norms(m0, m0.target_norm) +
                    get_leaf_target_norms(m1, m1.target_norm))
        else:
            # When mass is 0, we still need to traverse to get the structure right
            m0, m1 = module.children
            return (get_leaf_target_norms(m0, 0.0) +
                    get_leaf_target_norms(m1, 0.0))

    # If this is a TupleModule
    elif isinstance(module, TupleModule):
        if module.mass > 0:
            target_norms = []
            # Same logic as in TupleModule.dualize()
            for child in module.children:
                child.target_norm = target_norm * child.mass / module.mass
                target_norms.extend(get_leaf_target_norms(child, child.target_norm))
            return target_norms
        else:
            # When mass is 0, we still need to traverse to get the structure right
            target_norms = []
            for child in module.children:
                target_norms.extend(get_leaf_target_norms(child, 0.0))
            return target_norms

    # Any other Module type is assumed to contribute no leaves
    else:
        return []
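As a concrete illustration of the mass split (a sketch with arbitrary shapes, assuming Linear atoms of unit mass and unit sensitivity): a two-atom composite with target_norm = 1.0 hands each child 1.0 * 1 / 2 = 0.5.

```python
from modula.abstract import get_leaf_modules, get_leaf_target_norms
from modula.atom import Linear

mlp = Linear(fanout=4, fanin=3)
mlp @= Linear(fanout=2, fanin=4)

for leaf, norm in zip(get_leaf_modules(mlp), get_leaf_target_norms(mlp, target_norm=1.0)):
    print(type(leaf).__name__, norm)  # each leaf's share of the overall target norm
```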

def traverse_forward_order(module, func):
    """
    Traverse a module tree in the order that forward() would execute it, applying
    func to every module (children are visited before their parent).

    Args:
        module: The module to traverse
        func: Function to apply to each module. Should accept a module as argument.
    """
    def _traverse(mod):
        if isinstance(mod, CompositeModule):
            # For composite modules, traverse m0 first, then m1 (execution order)
            m0, m1 = mod.children
            _traverse(m0)
            _traverse(m1)
        elif isinstance(mod, TupleModule):
            # For tuple modules, traverse all children (they execute in parallel)
            for child in mod.children:
                _traverse(child)

        # Apply function to current module after traversing children
        func(mod)

    _traverse(module)
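A quick sketch of the visit order (hypothetical callback, arbitrary shapes): leaves are reported before the composite that contains them.

```python
from modula.abstract import traverse_forward_order
from modula.atom import Linear

mlp = Linear(fanout=4, fanin=3)
mlp @= Linear(fanout=2, fanin=4)

visited = []
traverse_forward_order(mlp, lambda mod: visited.append(type(mod).__name__))
print(visited)  # e.g. ['Linear', 'Linear', 'CompositeModule']
```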

def set_atomic_weights(module, weights):
    """
    Set weights as attributes on atomic modules by traversing in forward execution
    order. Assumes each atomic module has only one weight.

    Args:
        module: The root module to traverse
        weights: List of weights corresponding to atomic modules
    """
    weight_index = [0]  # Use a list to make it mutable in the closure

    def assign_weight(mod):
        if isinstance(mod, Atom):
            if weight_index[0] < len(weights):
                mod.weight = weights[weight_index[0]]
            # Count every atom, even after the weights run out, so the check below
            # also catches the case of too few weights
            weight_index[0] += 1

    traverse_forward_order(module, assign_weight)

    # Validate that the number of weights matches the number of atomic modules
    if weight_index[0] != len(weights):
        raise ValueError(f"Number of weights ({len(weights)}) doesn't match number of atomic modules ({weight_index[0]})")

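A minimal round trip through the helper (sketch only; assumes, as the helper itself does, that each Atom owns exactly one weight tensor):

```python
import jax
from modula.abstract import Atom, get_leaf_modules, set_atomic_weights
from modula.atom import Linear

mlp = Linear(fanout=4, fanin=3)
mlp @= Linear(fanout=2, fanin=4)

weights = mlp.initialize(jax.random.PRNGKey(0))  # one tensor per Linear atom
set_atomic_weights(mlp, weights)                  # attach each tensor as .weight
print([leaf.weight.shape for leaf in get_leaf_modules(mlp) if isinstance(leaf, Atom)])
```
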
71 changes: 70 additions & 1 deletion modula/atom.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp

-from modula.abstract import Atom
+from modula.abstract import Atom, Bond, CompositeModule, TupleModule

def orthogonalize(M):
    # six step Newton-Schulz by @YouJiacheng
@@ -92,6 +92,73 @@ def dualize(self, grad_w, target_norm=1.0):


if __name__ == "__main__":
    from modula.abstract import get_leaf_modules, get_leaf_target_norms

    def test_dualize_consistency(module, grad_w, target_norm=1.0, rtol=1e-6):
        """
        Test that get_unnormalized_dual and get_leaf_target_norms produce results
        consistent with the actual dualize method.

        Args:
            module: A Module instance
            grad_w: Weight gradient list
            target_norm: Target norm to test with
            rtol: Relative tolerance for comparison

        Returns:
            bool: True if consistent, False otherwise
        """
        # Get results from actual dualize
        actual_dual = module.dualize(grad_w, target_norm=target_norm)

        # Get results from our functions
        unnormalized_dual = get_unnormalized_dual(module, grad_w)
        leaf_modules = get_leaf_modules(module)
        target_norms = get_leaf_target_norms(module, target_norm=target_norm)

        # Apply target norms to unnormalized dual
        predicted_dual = []
        weight_idx = 0

        for leaf_module, leaf_target_norm in zip(leaf_modules, target_norms):
            if isinstance(leaf_module, Atom):  # Only atoms have weights
                leaf_weights = unnormalized_dual[weight_idx:weight_idx + leaf_module.atoms]
                # Apply the target norm
                scaled_weights = [w * leaf_target_norm for w in leaf_weights]
                predicted_dual.extend(scaled_weights)
                weight_idx += leaf_module.atoms
            # Bonds have no weights, so nothing to add to predicted_dual

        # Compare actual vs predicted
        if len(actual_dual) != len(predicted_dual):
            print(f"Length mismatch: actual {len(actual_dual)}, predicted {len(predicted_dual)}")
            return False

        for i, (actual, predicted) in enumerate(zip(actual_dual, predicted_dual)):
            if not jnp.allclose(actual, predicted, rtol=rtol):
                print(f"Mismatch at weight {i}")
                print(f"Actual shape: {actual.shape}, Predicted shape: {predicted.shape}")
                print(f"Max difference: {jnp.max(jnp.abs(actual - predicted))}")
                return False

        print("✓ Dualize consistency test passed!")
        return True

    # Example usage:
    def test_example():
        """Example test with a simple module"""
        # Create a simple module
        linear = Linear(fanout=4, fanin=3)
        linear @= Linear(fanout=4, fanin=4)  # Add another linear layer
        linear @= Linear(fanout=2, fanin=4)  # Add another linear layer

        # Initialize weights and create some gradient
        key = jax.random.PRNGKey(42)
        weights = linear.initialize(key)
        grad_w = [jax.random.normal(key, shape=w.shape) for w in weights]

        # Test consistency
        return test_dualize_consistency(linear, grad_w, target_norm=2.5)
Comment on lines +95 to +161
Author
Oops I didn't mean to leave this in, I was running some unrelated tests. I'll remove this.


    key = jax.random.PRNGKey(0)

@@ -116,3 +183,5 @@ def dualize(self, grad_w, target_norm=1.0):
    error_O = jnp.linalg.norm(O - U @ Vh) / jnp.linalg.norm(U @ Vh)
    print(f"relative error in M's SVD: {error_M}")
    print(f"relative error in O: {error_O}")

    test_example()
3 changes: 2 additions & 1 deletion modula/bond.py
@@ -105,6 +105,7 @@ def __init__(self, d_head, base=10000):
        self.sensitivity = 1  # rope is an orthogonal transformation

        self.rope_dim = d_head // 2
        self.base = base
        self.inverse_frequencies = 1/base**(jnp.arange(self.rope_dim) / self.rope_dim)
        self.seq_len_cached = None
        self.sin_cached = None
@@ -126,7 +127,7 @@ def rotate(self, x):
        x1 = x[..., self.rope_dim:]  # shape [batch, n_heads, seq_len, rope_dim]
        x2 = x[..., :self.rope_dim]  # shape [batch, n_heads, seq_len, rope_dim]

-        cos, sin = self.get_cached(seq_len)
+        sin, cos = self.get_cached(seq_len)
        y1 = cos * x1 + sin * x2
        y2 = -sin * x1 + cos * x2