|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. |
| 3 | + |
| 4 | +import math |
| 5 | +from dataclasses import dataclass |
| 6 | +from typing import Optional, Tuple |
| 7 | + |
| 8 | +import fairscale.nn.model_parallel.initialize as fs_init |
| 9 | +import torch |
| 10 | +import torch.nn.functional as F |
| 11 | +from fairscale.nn.model_parallel.layers import ( |
| 12 | + ColumnParallelLinear, |
| 13 | + RowParallelLinear, |
| 14 | + VocabParallelEmbedding, |
| 15 | +) |
| 16 | +from torch import nn |
| 17 | + |
| 18 | + |
| 19 | +@dataclass |
| 20 | +class ModelArgs: |
| 21 | + dim: int = 4096 |
| 22 | + n_layers: int = 32 |
| 23 | + n_heads: int = 32 |
| 24 | + n_kv_heads: Optional[int] = None |
| 25 | + vocab_size: int = -1 |
| 26 | + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 |
| 27 | + ffn_dim_multiplier: Optional[float] = None |
| 28 | + norm_eps: float = 1e-5 |
| 29 | + rope_theta: float = 500000 |
| 30 | + |
| 31 | + max_batch_size: int = 32 |
| 32 | + max_seq_len: int = 2048 |
| 33 | + |
| 34 | + |
| 35 | +class RMSNorm(torch.nn.Module): |
| 36 | + def __init__(self, dim: int, eps: float = 1e-6): |
| 37 | + super().__init__() |
| 38 | + self.eps = eps |
| 39 | + self.weight = nn.Parameter(torch.ones(dim)) |
| 40 | + |
| 41 | + def _norm(self, x): |
| 42 | + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
| 43 | + |
| 44 | + def forward(self, x): |
| 45 | + output = self._norm(x.float()).type_as(x) |
| 46 | + return output * self.weight |
| 47 | + |
| 48 | + |
| 49 | +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): |
| 50 | + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) |
| 51 | + t = torch.arange(end, device=freqs.device, dtype=torch.float32) |
| 52 | + freqs = torch.outer(t, freqs) |
| 53 | + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 |
| 54 | + return freqs_cis |
| 55 | + |
| 56 | + |
| 57 | +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): |
| 58 | + ndim = x.ndim |
| 59 | + assert 0 <= 1 < ndim |
| 60 | + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) |
| 61 | + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] |
| 62 | + return freqs_cis.view(*shape) |
| 63 | + |
| 64 | + |
| 65 | +def apply_rotary_emb( |
| 66 | + xq: torch.Tensor, |
| 67 | + xk: torch.Tensor, |
| 68 | + freqs_cis: torch.Tensor, |
| 69 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 70 | + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) |
| 71 | + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) |
| 72 | + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) |
| 73 | + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) |
| 74 | + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) |
| 75 | + return xq_out.type_as(xq), xk_out.type_as(xk) |
| 76 | + |
| 77 | + |
| 78 | +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: |
| 79 | + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" |
| 80 | + bs, slen, n_kv_heads, head_dim = x.shape |
| 81 | + if n_rep == 1: |
| 82 | + return x |
| 83 | + return ( |
| 84 | + x[:, :, :, None, :] |
| 85 | + .expand(bs, slen, n_kv_heads, n_rep, head_dim) |
| 86 | + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) |
| 87 | + ) |
| 88 | + |
| 89 | + |
| 90 | +class Attention(nn.Module): |
| 91 | + def __init__(self, args: ModelArgs): |
| 92 | + super().__init__() |
| 93 | + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads |
| 94 | + model_parallel_size = fs_init.get_model_parallel_world_size() |
| 95 | + self.n_local_heads = args.n_heads // model_parallel_size |
| 96 | + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size |
| 97 | + self.n_rep = self.n_local_heads // self.n_local_kv_heads |
| 98 | + self.head_dim = args.dim // args.n_heads |
| 99 | + |
| 100 | + self.wq = ColumnParallelLinear( |
| 101 | + args.dim, |
| 102 | + args.n_heads * self.head_dim, |
| 103 | + bias=False, |
| 104 | + gather_output=False, |
| 105 | + init_method=lambda x: x, |
| 106 | + ) |
| 107 | + self.wk = ColumnParallelLinear( |
| 108 | + args.dim, |
| 109 | + self.n_kv_heads * self.head_dim, |
| 110 | + bias=False, |
| 111 | + gather_output=False, |
| 112 | + init_method=lambda x: x, |
| 113 | + ) |
| 114 | + self.wv = ColumnParallelLinear( |
| 115 | + args.dim, |
| 116 | + self.n_kv_heads * self.head_dim, |
| 117 | + bias=False, |
| 118 | + gather_output=False, |
| 119 | + init_method=lambda x: x, |
| 120 | + ) |
| 121 | + self.wo = RowParallelLinear( |
| 122 | + args.n_heads * self.head_dim, |
| 123 | + args.dim, |
| 124 | + bias=False, |
| 125 | + input_is_parallel=True, |
| 126 | + init_method=lambda x: x, |
| 127 | + ) |
| 128 | + |
| 129 | + self.cache_k = torch.zeros( |
| 130 | + ( |
| 131 | + args.max_batch_size, |
| 132 | + args.max_seq_len, |
| 133 | + self.n_local_kv_heads, |
| 134 | + self.head_dim, |
| 135 | + ) |
| 136 | + ).cuda() |
| 137 | + self.cache_v = torch.zeros( |
| 138 | + ( |
| 139 | + args.max_batch_size, |
| 140 | + args.max_seq_len, |
| 141 | + self.n_local_kv_heads, |
| 142 | + self.head_dim, |
| 143 | + ) |
| 144 | + ).cuda() |
| 145 | + |
| 146 | + def forward( |
| 147 | + self, |
| 148 | + x: torch.Tensor, |
| 149 | + start_pos: int, |
| 150 | + freqs_cis: torch.Tensor, |
| 151 | + mask: Optional[torch.Tensor], |
| 152 | + ): |
| 153 | + bsz, seqlen, _ = x.shape |
| 154 | + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) |
| 155 | + |
| 156 | + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
| 157 | + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) |
| 158 | + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) |
| 159 | + |
| 160 | + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) |
| 161 | + |
| 162 | + self.cache_k = self.cache_k.to(xq) |
| 163 | + self.cache_v = self.cache_v.to(xq) |
| 164 | + |
| 165 | + self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk |
| 166 | + self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv |
| 167 | + |
| 168 | + keys = self.cache_k[:bsz, : start_pos + seqlen] |
| 169 | + values = self.cache_v[:bsz, : start_pos + seqlen] |
| 170 | + |
| 171 | + # repeat k/v heads if n_kv_heads < n_heads |
| 172 | + keys = repeat_kv( |
| 173 | + keys, self.n_rep |
| 174 | + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) |
| 175 | + values = repeat_kv( |
| 176 | + values, self.n_rep |
| 177 | + ) # (bs, cache_len + seqlen, n_local_heads, head_dim) |
| 178 | + |
| 179 | + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) |
| 180 | + keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) |
| 181 | + values = values.transpose( |
| 182 | + 1, 2 |
| 183 | + ) # (bs, n_local_heads, cache_len + seqlen, head_dim) |
| 184 | + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) |
| 185 | + if mask is not None: |
| 186 | + scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) |
| 187 | + scores = F.softmax(scores.float(), dim=-1).type_as(xq) |
| 188 | + output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) |
| 189 | + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) |
| 190 | + return self.wo(output) |
| 191 | + |
| 192 | + |
| 193 | +class FeedForward(nn.Module): |
| 194 | + def __init__( |
| 195 | + self, |
| 196 | + dim: int, |
| 197 | + hidden_dim: int, |
| 198 | + multiple_of: int, |
| 199 | + ffn_dim_multiplier: Optional[float], |
| 200 | + ): |
| 201 | + super().__init__() |
| 202 | + hidden_dim = int(2 * hidden_dim / 3) |
| 203 | + # custom dim factor multiplier |
| 204 | + if ffn_dim_multiplier is not None: |
| 205 | + hidden_dim = int(ffn_dim_multiplier * hidden_dim) |
| 206 | + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) |
| 207 | + |
| 208 | + self.w1 = ColumnParallelLinear( |
| 209 | + dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x |
| 210 | + ) |
| 211 | + self.w2 = RowParallelLinear( |
| 212 | + hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x |
| 213 | + ) |
| 214 | + self.w3 = ColumnParallelLinear( |
| 215 | + dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x |
| 216 | + ) |
| 217 | + |
| 218 | + def forward(self, x): |
| 219 | + return self.w2(F.silu(self.w1(x)) * self.w3(x)) |
| 220 | + |
| 221 | + |
| 222 | +class TransformerBlock(nn.Module): |
| 223 | + def __init__(self, layer_id: int, args: ModelArgs): |
| 224 | + super().__init__() |
| 225 | + self.n_heads = args.n_heads |
| 226 | + self.dim = args.dim |
| 227 | + self.head_dim = args.dim // args.n_heads |
| 228 | + self.attention = Attention(args) |
| 229 | + self.feed_forward = FeedForward( |
| 230 | + dim=args.dim, |
| 231 | + hidden_dim=4 * args.dim, |
| 232 | + multiple_of=args.multiple_of, |
| 233 | + ffn_dim_multiplier=args.ffn_dim_multiplier, |
| 234 | + ) |
| 235 | + self.layer_id = layer_id |
| 236 | + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) |
| 237 | + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) |
| 238 | + |
| 239 | + def forward( |
| 240 | + self, |
| 241 | + x: torch.Tensor, |
| 242 | + start_pos: int, |
| 243 | + freqs_cis: torch.Tensor, |
| 244 | + mask: Optional[torch.Tensor], |
| 245 | + ): |
| 246 | + h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) |
| 247 | + out = h + self.feed_forward(self.ffn_norm(h)) |
| 248 | + return out |
| 249 | + |
| 250 | + |
| 251 | +class Transformer(nn.Module): |
| 252 | + def __init__(self, params: ModelArgs): |
| 253 | + super().__init__() |
| 254 | + self.params = params |
| 255 | + self.vocab_size = params.vocab_size |
| 256 | + self.n_layers = params.n_layers |
| 257 | + |
| 258 | + self.tok_embeddings = VocabParallelEmbedding( |
| 259 | + params.vocab_size, params.dim, init_method=lambda x: x |
| 260 | + ) |
| 261 | + |
| 262 | + self.layers = torch.nn.ModuleList() |
| 263 | + for layer_id in range(params.n_layers): |
| 264 | + self.layers.append(TransformerBlock(layer_id, params)) |
| 265 | + |
| 266 | + self.norm = RMSNorm(params.dim, eps=params.norm_eps) |
| 267 | + self.output = ColumnParallelLinear( |
| 268 | + params.dim, params.vocab_size, bias=False, init_method=lambda x: x |
| 269 | + ) |
| 270 | + |
| 271 | + self.freqs_cis = precompute_freqs_cis( |
| 272 | + params.dim // params.n_heads, |
| 273 | + params.max_seq_len * 2, |
| 274 | + params.rope_theta, |
| 275 | + ) |
| 276 | + |
| 277 | + @torch.inference_mode() |
| 278 | + def forward(self, tokens: torch.Tensor, start_pos: int): |
| 279 | + _bsz, seqlen = tokens.shape |
| 280 | + h = self.tok_embeddings(tokens) |
| 281 | + self.freqs_cis = self.freqs_cis.to(h.device) |
| 282 | + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] |
| 283 | + |
| 284 | + mask = None |
| 285 | + if seqlen > 1: |
| 286 | + mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device) |
| 287 | + |
| 288 | + mask = torch.triu(mask, diagonal=1) |
| 289 | + |
| 290 | + # When performing key-value caching, we compute the attention scores |
| 291 | + # only for the new sequence. Thus, the matrix of scores is of size |
| 292 | + # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for |
| 293 | + # j > cache_len + i, since row i corresponds to token cache_len + i. |
| 294 | + mask = torch.hstack( |
| 295 | + [torch.zeros((seqlen, start_pos), device=tokens.device), mask] |
| 296 | + ).type_as(h) |
| 297 | + |
| 298 | + for layer in self.layers: |
| 299 | + h = layer(h, start_pos, freqs_cis, mask) |
| 300 | + h = self.norm(h) |
| 301 | + output = self.output(h).float() |
| 302 | + return output |
0 commit comments