Fix 934 #1155

2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -185,4 +185,4 @@
strictDictionaryInference=true
strictListInference=true
strictParameterNoneValue=true
strictSetInference=true
strictSetInference=true
43 changes: 31 additions & 12 deletions tests/unit/factored_matrix/test_properties.py
@@ -32,6 +32,19 @@ def factored_matrices_leading_ones(random_matrices_leading_ones):
return [FactoredMatrix(a, b) for a, b in random_matrices_leading_ones]


@pytest.fixture(scope="module")
def random_matrices_bf16():
return [
(randn(3, 2).to(torch.bfloat16), randn(2, 3).to(torch.bfloat16)),
(randn(10, 4).to(torch.bfloat16), randn(4, 10).to(torch.bfloat16)),
]


@pytest.fixture(scope="module")
def factored_matrices_bf16(random_matrices_bf16):
return [FactoredMatrix(a, b) for a, b in random_matrices_bf16]


class TestFactoredMatrixProperties:
def test_AB_property(self, factored_matrices, random_matrices):
for i, factored_matrix in enumerate(factored_matrices):
@@ -79,18 +92,6 @@ def test_svd_property_leading_ones(self, factored_matrices_leading_ones):
assert torch.allclose(U.mT @ U, torch.eye(U.shape[-1]), atol=1e-5)
assert torch.allclose(Vh.mT @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5)

@pytest.mark.skip(
"""
Jaxtyping throws a TypeError when this test is run.
TypeError: type of the return value must be jaxtyping.Float[Tensor, '*leading_dims mdim']; got torch.Tensor instead

I'm not sure why. The error is not very informative. When debugging the shape was equal to mdim, and *leading_dims should
match zero or more leading dims according to the [docs](https://github.com/google/jaxtyping/blob/main/API.md).

Sort of related to https://github.com/TransformerLensOrg/TransformerLens/issues/190 because jaxtyping
is only enabled at test time and not runtime.
"""
)
def test_eigenvalues_property(self, factored_matrices):
for factored_matrix in factored_matrices:
if factored_matrix.ldim == factored_matrix.rdim:
@@ -159,3 +160,21 @@ def test_unsqueeze(self, factored_matrices_leading_ones):
assert isinstance(result, FactoredMatrix)
assert torch.allclose(result.A, unsqueezed_A)
assert torch.allclose(result.B, unsqueezed_B)

def test_eigenvalues_bfloat16_support(self, factored_matrices_bf16):
"""
Test that the eigenvalues calculation does not crash for bfloat16 matrices.
"""
for factored_matrix in factored_matrices_bf16:
if factored_matrix.ldim == factored_matrix.rdim:
eigenvalues = factored_matrix.eigenvalues

assert eigenvalues.dtype == torch.complex64

expected_eigenvalues = torch.linalg.eig(
factored_matrix.BA.to(torch.float32)
).eigenvalues

assert torch.allclose(
torch.abs(eigenvalues), torch.abs(expected_eigenvalues), atol=1e-2, rtol=1e-2
)
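For reference, a minimal standalone sketch (plain PyTorch, not part of the test suite) of the failure this test guards against: torch.linalg.eig is not implemented for 16-bit dtypes, so bfloat16 factors have to be upcast before computing eigenvalues.

import torch

m = torch.randn(4, 4).to(torch.bfloat16)

# Calling eig directly on a bfloat16 matrix raises a RuntimeError on CPU/CUDA.
try:
    torch.linalg.eig(m)
except RuntimeError as err:
    print(f"eig on bfloat16 fails: {err}")

# Upcasting to float32 works and yields complex64 eigenvalues, matching the assertion above.
eigenvalues = torch.linalg.eig(m.to(torch.float32)).eigenvalues
print(eigenvalues.dtype)  # torch.complex64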
15 changes: 11 additions & 4 deletions transformer_lens/FactoredMatrix.py
@@ -10,7 +10,7 @@
from typing import List, Tuple, Union, overload

import torch
from jaxtyping import Float
from jaxtyping import Complex, Float

import transformer_lens.utils as utils

@@ -189,9 +189,16 @@ def Vh(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]:
return self.svd()[2]

@property
def eigenvalues(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
"""Eigenvalues of AB are the same as for BA (apart from trailing zeros), because if BAv=kv ABAv = A(BAv)=kAv, so Av is an eigenvector of AB with eigenvalue k."""
return torch.linalg.eig(self.BA).eigenvalues
def eigenvalues(self) -> Complex[torch.Tensor, "*leading_dims mdim"]:
"""
Eigenvalues of AB are the same as for BA (apart from trailing zeros), because if BAv = kv, then AB(Av) = A(BAv) = kAv,
so Av is an eigenvector of AB with eigenvalue k.
"""
input_matrix = self.BA
if input_matrix.dtype in [torch.bfloat16, torch.float16]:
# Cast to float32 because eig is not implemented for 16-bit on CPU/CUDA
input_matrix = input_matrix.to(torch.float32)
return torch.linalg.eig(input_matrix).eigenvalues

def _convert_to_slice(self, sequence: Union[Tuple, List], idx: int) -> Tuple:
"""
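A minimal usage sketch of the fixed property (assuming the top-level export from transformer_lens import FactoredMatrix; shapes chosen so that BA is square):

import torch
from transformer_lens import FactoredMatrix

# bfloat16 factors: A is 3x2 and B is 2x3, so BA is 2x2 and eigenvalues are defined.
A = torch.randn(3, 2).to(torch.bfloat16)
B = torch.randn(2, 3).to(torch.bfloat16)
fm = FactoredMatrix(A, B)

# Previously this raised because torch.linalg.eig lacks bfloat16 support;
# with the cast to float32 it returns complex64 eigenvalues of shape (mdim,).
print(fm.eigenvalues.dtype)   # torch.complex64
print(fm.eigenvalues.shape)   # torch.Size([2])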
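And a quick numerical check of the docstring's claim that AB and BA share their nonzero eigenvalues (plain PyTorch, no TransformerLens dependency):

import torch

A = torch.randn(5, 3)
B = torch.randn(3, 5)

# BA is 3x3; AB is 5x5 with two extra (near-)zero eigenvalues.
eig_BA = torch.linalg.eig(B @ A).eigenvalues
eig_AB = torch.linalg.eig(A @ B).eigenvalues

# Compare the three largest eigenvalue magnitudes of AB with those of BA.
top_AB = torch.sort(eig_AB.abs(), descending=True).values[:3]
top_BA = torch.sort(eig_BA.abs(), descending=True).values
print(torch.allclose(top_AB, top_BA, atol=1e-4))  # True, up to numerical error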