
Commit f5c65dd

Add the layout for testing
1 parent 7416cbf commit f5c65dd

File tree

7 files changed: +1067 -3 lines changed

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
"""Tests for MLIR generation."""
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
"""
Pytest configuration for MLIR generation tests.

This file sets up fixtures and configuration for tests in this directory.
"""

import os
import pytest


@pytest.fixture(scope="session", autouse=True)
def setup_mlir_environment():
    """
    Set up MLIR environment variables for testing.
    This runs once per test session.
    """
    # Set environment variable for MLIR shared libraries if needed
    # The default empty string is fine for most cases
    if "LIGHTHOUSE_SHARED_LIBS" not in os.environ:
        os.environ["LIGHTHOUSE_SHARED_LIBS"] = ""

    yield

    # Cleanup after all tests (if needed)
    pass


@pytest.fixture
def mlir_context():
    """
    Provide a fresh MLIR context for each test.
    """
    from mlir import ir

    return ir.Context()


@pytest.fixture
def sample_shapes():
    """
    Provide common tensor shapes for testing.
    """
    return [
        (4, 16),
        (8, 8),
        (16, 32),
        (1, 64),
    ]


@pytest.fixture
def sample_types():
    """
    Provide common element types for testing.
    """
    return ["f32", "f64"]
Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
)
from torch import nn


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000

    max_batch_size: int = 32
    max_seq_len: int = 2048


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

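A minimal shape sketch of the rotary-embedding helpers above, using only the functions defined in this file and plain torch on CPU; the concrete sizes are illustrative and are not taken from the commit.

# Illustrative sizes only: head_dim=8, seqlen=4, batch=2, 3 heads.
head_dim, seqlen, bsz, n_heads = 8, 4, 2, 3
freqs_cis_demo = precompute_freqs_cis(head_dim, seqlen)  # (4, 4) complex64
xq_demo = torch.randn(bsz, seqlen, n_heads, head_dim)
xk_demo = torch.randn(bsz, seqlen, n_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq_demo, xk_demo, freqs_cis_demo)
# Rotation is applied per (position, head) pair; shapes are preserved.
assert xq_rot.shape == xq_demo.shape and xk_rot.shape == xk_demo.shape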
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

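As a quick sketch of what repeat_kv does for grouped-query attention (numbers are illustrative, not from the commit): 2 KV heads expanded to match 8 query heads means n_rep = 4.

# Illustrative only: (bs, seqlen, n_kv_heads, head_dim) = (1, 5, 2, 16).
kv_demo = torch.randn(1, 5, 2, 16)
expanded = repeat_kv(kv_demo, n_rep=4)
assert expanded.shape == (1, 5, 8, 16)  # KV heads repeated along dim 2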
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

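A worked instance of the SwiGLU hidden-size rounding in FeedForward.__init__, assuming the default ModelArgs (dim=4096, ffn_dim_multiplier=None, multiple_of=256); the numbers are plain arithmetic on the code above, not new behavior.

# TransformerBlock passes hidden_dim = 4 * dim.
hidden_dim = 4 * 4096                               # 16384
hidden_dim = int(2 * hidden_dim / 3)                # 10922
hidden_dim = 256 * ((hidden_dim + 256 - 1) // 256)  # 11008, rounded up to a multiple of 256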
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = VocabParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)

            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output
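To make the KV-cache mask construction in Transformer.forward concrete, here is a small standalone sketch using only torch; the sizes (3 new tokens on top of 2 cached positions) are illustrative and not taken from the commit.

# Illustrative only: seqlen=3 new tokens, start_pos=2 positions already cached.
seqlen, start_pos = 3, 2
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])
assert mask.shape == (seqlen, start_pos + seqlen)
# Row i attends to every cached position and to new positions j <= i:
# [[0., 0., 0., -inf, -inf],
#  [0., 0., 0.,   0., -inf],
#  [0., 0., 0.,   0.,   0.]]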

0 commit comments
