diff --git a/LLMs/torch_examples/pooling.py b/LLMs/torch_examples/pooling.py
new file mode 100644
index 0000000..085409b
--- /dev/null
+++ b/LLMs/torch_examples/pooling.py
@@ -0,0 +1,22 @@
+import torch
+
+class VariableSortedHistoryPooling(torch.nn.Module):
+    def __init__(self, n_samples: int, emb_dim: int):
+        super().__init__()
+        # n_samples is the number of distinct event ids (the embedding vocabulary).
+        # A batch is a flat run of N events, segmented into B users by offsets.
+        self.emb = torch.nn.Embedding(n_samples, emb_dim)
+
+    def forward(self, event_indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
+        event_embs = self.emb(event_indices)
+        # Diffs of the cumulative offsets give per-user history lengths.
+        user_lengths = offsets[1:] - offsets[:-1]
+        # Map each event row to its owning user, e.g. [0, 0, 0, 1, 2, 2, 2, 2, 3, 3].
+        user_ids = torch.repeat_interleave(torch.arange(len(user_lengths),
+                                                        device=offsets.device),
+                                           user_lengths)
+        # Sum each user's event embeddings into one row, then divide by the
+        # history length (clamped so empty histories yield zeros, not NaNs).
+        target = torch.zeros(len(user_lengths), event_embs.shape[1], device=event_embs.device)
+        target = target.scatter_add(dim=0, index=user_ids.unsqueeze(1).expand_as(event_embs), src=event_embs)
+        return target / user_lengths.clamp(min=1).unsqueeze(1)
diff --git a/LLMs/torch_examples/shape_literacy.py b/LLMs/torch_examples/shape_literacy.py
new file mode 100644
index 0000000..be58126
--- /dev/null
+++ b/LLMs/torch_examples/shape_literacy.py
@@ -0,0 +1,42 @@
+import torch
+
+# unsqueeze(1) inserts a new size-1 dimension at position 1.
+x = torch.randn(10)
+y = x.unsqueeze(1)
+print(x.shape, y.shape)  # torch.Size([10]) torch.Size([10, 1])
+
+# squeeze() drops all size-1 dimensions.
+x = torch.randn(4, 1, 8)
+y = x.squeeze()
+print(x.shape, y.shape)  # torch.Size([4, 1, 8]) torch.Size([4, 8])
+
+x = torch.randn(2, 3, 4)
+print(x, x.shape)
+# view/reshape keep the same number of elements but change the shape;
+# view requires a compatible memory layout, reshape copies when it must.
+y = x.view(12, 2)
+print(y, y.shape)
+z = x.reshape(12, 2)
+print(z, z.shape)
+
+# expand broadcasts size-1 dimensions without allocating new memory.
+x = torch.randn(5, 1)
+z = x.expand(5, 3)
+print(x, z)
+print(x.shape, z.shape)  # torch.Size([5, 1]) torch.Size([5, 3])
+
+# repeat copies the data, so writes to y do not touch x.
+y = x.repeat(1, 3)
+print(y)
+print(y.shape)
+y[0, 0] = 10
+print("Y", y)
+
+# expand_as broadcasts to match another tensor's shape:
+# dim 0 stays 5, dim 1 is stretched from 1 to 10.
+y2 = x.expand_as(torch.randn(5, 10))
+# expand doesn't allocate new memory, so writing through y3 also changes x.
+y3 = x.expand(5, 2)
+y3[0, 0] = 20
+print(y3)
+print(x)  # x[0, 0] is now 20 as well
diff --git a/tests/test_pooling.py b/tests/test_pooling.py
new file mode 100644
index 0000000..c894e59
--- /dev/null
+++ b/tests/test_pooling.py
@@ -0,0 +1,37 @@
+import torch
+from LLMs.torch_examples.pooling import VariableSortedHistoryPooling
+
+
+def loop_pooling(emb, event_indices, offsets):
+    # Reference implementation: pool each user's history with an explicit loop.
+    pooled = []
+    for i in range(len(offsets) - 1):
+        start, end = offsets[i].item(), offsets[i + 1].item()
+        user_embs = emb(event_indices[start:end])
+        pooled.append(user_embs.mean(dim=0))
+    return torch.stack(pooled, dim=0)
+
+
+def test_variable_sorted_history_pooling():
+    torch.manual_seed(0)
+
+    # ----- Test configuration -----
+    vocab_size = 50
+    emb_dim = 8
+    B = 4  # number of users
+    lengths = torch.tensor([3, 1, 4, 2])
+    offsets = torch.cat([torch.tensor([0]), lengths.cumsum(0)])
+    N = offsets[-1].item()
+
+    event_indices = torch.randint(0, vocab_size, (N,))
+
+    # ----- Model -----
+    model = VariableSortedHistoryPooling(vocab_size, emb_dim)
+
+    # ----- Forward passes -----
+    out_vec = model(event_indices, offsets)
+    out_loop = loop_pooling(model.emb, event_indices, offsets)
+
+    # ----- Assertions -----
+    assert out_vec.shape == (B, emb_dim)
+    assert torch.allclose(out_vec, out_loop, atol=1e-6)
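
A quick sketch of the index-building trick used in forward() above. The offsets values here are illustrative, chosen to match the test's lengths [3, 1, 4, 2]; the index_add_ variant at the end is an equivalent alternative reduction, not what the diff uses:

import torch

offsets = torch.tensor([0, 3, 4, 8, 10])    # B = 4 users, N = 10 events
user_lengths = offsets[1:] - offsets[:-1]   # tensor([3, 1, 4, 2])
user_ids = torch.repeat_interleave(torch.arange(4), user_lengths)
print(user_ids)                             # tensor([0, 0, 0, 1, 2, 2, 2, 2, 3, 3])

# Equivalent reduction without broadcasting the index: index_add_ routes
# each event row into its user's row directly.
emb_dim = 8
event_embs = torch.randn(10, emb_dim)
target = torch.zeros(4, emb_dim)
target.index_add_(0, user_ids, event_embs)  # same per-user sums as scatter_add

index_add_ takes the 1-D index as-is, which reads a little more simply than scatter_add with an index expanded to the embedding shape.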
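
For comparison, torch.nn.EmbeddingBag implements this exact mean-over-variable-length-bags pattern natively, fusing the lookup and the reduction. A minimal sketch, assuming the same (B + 1)-element cumulative offsets convention as the module above (hence include_last_offset=True):

import torch

vocab_size, emb_dim = 50, 8
bag = torch.nn.EmbeddingBag(vocab_size, emb_dim, mode="mean",
                            include_last_offset=True)

event_indices = torch.randint(0, vocab_size, (10,))
offsets = torch.tensor([0, 3, 4, 8, 10])  # 4 users with lengths 3, 1, 4, 2

pooled = bag(event_indices, offsets)
print(pooled.shape)                       # torch.Size([4, 8])

Because the lookup and reduction are fused, EmbeddingBag avoids materializing the intermediate (N, emb_dim) embedding tensor that the scatter_add version creates.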
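
The aliasing claims in shape_literacy.py can also be verified directly: a view produced by expand shares x's storage pointer, while repeat owns a fresh copy. A sketch, using the same (5, 1) tensor shape as the file:

import torch

x = torch.randn(5, 1)
expanded = x.expand(5, 3)  # view: stride (1, 0), no new memory
repeated = x.repeat(1, 3)  # copy: fresh storage

print(expanded.data_ptr() == x.data_ptr())  # True  (shared storage)
print(repeated.data_ptr() == x.data_ptr())  # False (independent copy)
print(expanded.stride())                    # (1, 0)

# A write through the view lands in x, and shows up in every aliased
# position of that row, because the stride along dim 1 is 0.
expanded[0, 0] = 20.0
print(x[0, 0])      # tensor(20.)
print(expanded[0])  # tensor([20., 20., 20.])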