Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions LLMs/torch_examples/pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch

class VariableSortedHistoryPooling(torch.nn.Module):
    """Mean-pool event embeddings over contiguous, variable-length user segments.

    The events of all users are packed into one flat index tensor in which the
    events of a given user are consecutive.  ``offsets`` holds the cumulative
    segment boundaries, so user ``i`` owns ``event_indices[offsets[i]:offsets[i + 1]]``.
    """

    def __init__(self, n_samples: int, emb_dim: int):
        """Create the embedding table.

        Args:
            n_samples: Number of rows in the embedding table, i.e. the exclusive
                upper bound on the values that may appear in ``event_indices``.
            emb_dim: Size of each embedding vector.
        """
        super().__init__()
        self.emb = torch.nn.Embedding(n_samples, emb_dim)

    def forward(self, event_indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        """Return one mean-pooled embedding per user.

        Args:
            event_indices: 1-D integer tensor with the embedding index of every
                event, grouped by user.
            offsets: 1-D integer tensor of cumulative segment boundaries with one
                more entry than there are users.  The segment lengths must sum
                to ``len(event_indices)``.

        Returns:
            Tensor of shape ``(num_users, emb_dim)`` in the dtype and on the
            device of the embedding table.  Users with an empty history get an
            all-zero row.
        """
        event_embs = self.emb(event_indices)
        device = event_embs.device

        # Diffs of the cumulative offsets give the number of events per user.
        # Move them to the embedding device so that inputs prepared on another
        # device (typically CPU) do not break the scatter or the division.
        user_lengths = (offsets[1:] - offsets[:-1]).to(device)
        num_users = user_lengths.shape[0]

        # Row j of event_embs belongs to user user_ids[j].
        user_ids = torch.repeat_interleave(
            torch.arange(num_users, device=device),
            user_lengths,
        )

        # new_zeros inherits dtype *and* device from the embeddings; a plain
        # torch.zeros would always be float32 and make scatter_add fail for a
        # module cast with .double() / .half().
        target = event_embs.new_zeros(num_users, event_embs.shape[1])
        target = target.scatter_add(
            dim=0,
            index=user_ids.unsqueeze(1).expand_as(event_embs),
            src=event_embs,
        )

        # Clamp the divisor so that empty histories yield zeros instead of NaN.
        return target / user_lengths.clamp(min=1).unsqueeze(1)
37 changes: 37 additions & 0 deletions LLMs/torch_examples/shape_literacy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch

# unsqueeze inserts a size-1 axis at the given position: (10,) -> (10, 1)
x = torch.randn(10)
y = x.unsqueeze(1)
print(x.shape, y.shape)

# squeeze() without arguments drops every size-1 axis: (4, 1, 8) -> (4, 8)
x = torch.randn(4, 1, 8)
y = x.squeeze()
print(x.shape, y.shape)

x = torch.randn(2, 3, 4)
print(x, x.shape)
# Reshape to the same number of elements (24) but a different shape.
# view never copies and needs compatible strides; reshape copies if it has to.
y = x.view(12, 2)
print(y, y.shape)
z = x.reshape(12, 2)
print(z, z.shape)

# expand broadcasts a size-1 axis without copying: (5, 1) -> (5, 3)
x = torch.randn(5, 1)
z = x.expand(5, 3)
# Show the source next to its expanded view (not the stale y from above).
print(x, z)
print(x.shape, z.shape)

# repeat copies the data, so writing to y leaves x untouched
y = x.repeat(1, 3)
print(y)
print(y.shape)
y[0, 0] = 10
print("Y", y)


# expand_as broadcasts x to the shape of another tensor: (5, 1) -> (5, 10)
y2 = x.expand_as(torch.randn(5, 10))
# expand doesn't allocate new memory, so changing y3 will change x
# (and every other column of that row, since they alias the same element)
y3 = x.expand(5, 2)
y3[0, 0] = 20
print(y3)
36 changes: 36 additions & 0 deletions tests/test_pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
import pytest
from LLMs.torch_examples.pooling import VariableSortedHistoryPooling

def loop_pooling(emb, event_indices, offsets):
    """Reference mean pooling that walks the user segments one at a time.

    Args:
        emb: Callable (e.g. ``torch.nn.Embedding``) mapping a 1-D index tensor
            to a ``(num_events, emb_dim)`` tensor.
        event_indices: 1-D integer tensor of event indices grouped by user.
        offsets: 1-D integer tensor of cumulative segment boundaries; user ``i``
            owns ``event_indices[offsets[i]:offsets[i + 1]]``.

    Returns:
        Tensor of shape ``(num_users, emb_dim)`` holding the mean embedding of
        each user.  An empty segment yields an all-zero row, which is what
        ``VariableSortedHistoryPooling`` produces by clamping its divisor to 1
        (a plain mean over zero rows would be NaN).
    """
    # One host transfer for all boundaries instead of two .item() calls per user.
    bounds = offsets.tolist()

    pooled = []
    for start, end in zip(bounds, bounds[1:]):
        user_embs = emb(event_indices[start:end])
        if user_embs.shape[0] == 0:
            pooled.append(user_embs.new_zeros(user_embs.shape[1]))
        else:
            pooled.append(user_embs.mean(dim=0))

    if not pooled:
        # No users at all: torch.stack rejects an empty list, so return an
        # empty (0, emb_dim) result derived from the embedding itself.
        return emb(event_indices[:0])
    return torch.stack(pooled, dim=0)


def test_variable_sorted_history_pooling():
    """Check the vectorised pooling against the per-user loop reference."""
    torch.manual_seed(0)

    # Fixture: four users with histories of 3, 1, 4 and 2 events.
    vocab_size, emb_dim = 50, 8
    num_users = 4
    history_lengths = torch.tensor([3, 1, 4, 2])
    leading_zero = torch.tensor([0])
    segment_offsets = torch.cat([leading_zero, history_lengths.cumsum(0)])
    total_events = segment_offsets[-1].item()

    events = torch.randint(0, vocab_size, (total_events,))

    # Module under test; its embedding table is shared with the reference.
    pooler = VariableSortedHistoryPooling(vocab_size, emb_dim)

    vectorised = pooler(events, segment_offsets)
    reference = loop_pooling(pooler.emb, events, segment_offsets)

    # One row per user, and both implementations agree numerically.
    assert vectorised.shape == (num_users, emb_dim)
    assert torch.allclose(vectorised, reference, atol=1e-6)