diff --git a/poetry.lock b/poetry.lock
index 62bb88aac..cba4dcb0c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -6182,4 +6182,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8,<4.0"
-content-hash = "7572fab0b6315f222cb5a72895eefa24ac9ca8b7e91a0a418e21c226072474cf"
+content-hash = "7572fab0b6315f222cb5a72895eefa24ac9ca8b7e91a0a418e21c226072474cf"
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 9fcce44d4..b5d8424e2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -185,4 +185,4 @@
 strictDictionaryInference=true
 strictListInference=true
 strictParameterNoneValue=true
- strictSetInference=true
+ strictSetInference=true
\ No newline at end of file
diff --git a/tests/unit/factored_matrix/test_properties.py b/tests/unit/factored_matrix/test_properties.py
index 28a7ad87b..e7e36fdca 100644
--- a/tests/unit/factored_matrix/test_properties.py
+++ b/tests/unit/factored_matrix/test_properties.py
@@ -32,6 +32,19 @@ def factored_matrices_leading_ones(random_matrices_leading_ones):
     return [FactoredMatrix(a, b) for a, b in random_matrices_leading_ones]


+@pytest.fixture(scope="module")
+def random_matrices_bf16():
+    return [
+        (randn(3, 2).to(torch.bfloat16), randn(2, 3).to(torch.bfloat16)),
+        (randn(10, 4).to(torch.bfloat16), randn(4, 10).to(torch.bfloat16)),
+    ]
+
+
+@pytest.fixture(scope="module")
+def factored_matrices_bf16(random_matrices_bf16):
+    return [FactoredMatrix(a, b) for a, b in random_matrices_bf16]
+
+
 class TestFactoredMatrixProperties:
     def test_AB_property(self, factored_matrices, random_matrices):
         for i, factored_matrix in enumerate(factored_matrices):
@@ -79,18 +92,6 @@ def test_svd_property_leading_ones(self, factored_matrices_leading_ones):
             assert torch.allclose(U.mT @ U, torch.eye(U.shape[-1]), atol=1e-5)
             assert torch.allclose(Vh.mT @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5)

-    @pytest.mark.skip(
-        """
-        Jaxtyping throws a TypeError when this test is run.
-        TypeError: type of the return value must be jaxtyping.Float[Tensor, '*leading_dims mdim']; got torch.Tensor instead
-
-        I'm not sure why. The error is not very informative. When debugging the shape was equal to mdim, and *leading_dims should
-        match zero or more leading dims according to the [docs](https://github.com/google/jaxtyping/blob/main/API.md).
-
-        Sort of related to https://github.com/TransformerLensOrg/TransformerLens/issues/190 because jaxtyping
-        is only enabled at test time and not runtime.
-        """
-    )
     def test_eigenvalues_property(self, factored_matrices):
         for factored_matrix in factored_matrices:
             if factored_matrix.ldim == factored_matrix.rdim:
@@ -159,3 +160,21 @@ def test_unsqueeze(self, factored_matrices_leading_ones):
         assert isinstance(result, FactoredMatrix)
         assert torch.allclose(result.A, unsqueezed_A)
         assert torch.allclose(result.B, unsqueezed_B)
+
+    def test_eigenvalues_bfloat16_support(self, factored_matrices_bf16):
+        """
+        Test that the eigenvalues calculation does not crash for bfloat16 matrices.
+        """
+        for factored_matrix in factored_matrices_bf16:
+            if factored_matrix.ldim == factored_matrix.rdim:
+                eigenvalues = factored_matrix.eigenvalues
+
+                assert eigenvalues.dtype == torch.complex64
+
+                expected_eigenvalues = torch.linalg.eig(
+                    factored_matrix.BA.to(torch.float32)
+                ).eigenvalues
+
+                assert torch.allclose(
+                    torch.abs(eigenvalues), torch.abs(expected_eigenvalues), atol=1e-2, rtol=1e-2
+                )
diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py
index 1e1c813a6..6c8140e87 100644
--- a/transformer_lens/FactoredMatrix.py
+++ b/transformer_lens/FactoredMatrix.py
@@ -10,7 +10,7 @@
 from typing import List, Tuple, Union, overload

 import torch
-from jaxtyping import Float
+from jaxtyping import Complex, Float

 import transformer_lens.utils as utils

@@ -189,9 +189,16 @@ def Vh(self) -> Float[torch.Tensor, "*leading_dims rdim mdim"]:
         return self.svd()[2]

     @property
-    def eigenvalues(self) -> Float[torch.Tensor, "*leading_dims mdim"]:
-        """Eigenvalues of AB are the same as for BA (apart from trailing zeros), because if BAv=kv ABAv = A(BAv)=kAv, so Av is an eigenvector of AB with eigenvalue k."""
-        return torch.linalg.eig(self.BA).eigenvalues
+    def eigenvalues(self) -> Complex[torch.Tensor, "*leading_dims mdim"]:
+        """
+        Eigenvalues of AB are the same as for BA (apart from trailing zeros), because if BAv = kv, then ABAv = A(BAv) = kAv,
+        so Av is an eigenvector of AB with eigenvalue k.
+        """
+        input_matrix = self.BA
+        if input_matrix.dtype in [torch.bfloat16, torch.float16]:
+            # Cast to float32 because eig is not implemented for 16-bit on CPU/CUDA
+            input_matrix = input_matrix.to(torch.float32)
+        return torch.linalg.eig(input_matrix).eigenvalues

     def _convert_to_slice(self, sequence: Union[Tuple, List], idx: int) -> Tuple:
         """