import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange


class Residual(nn.Module):
    """Generic residual connection: adds the wrapped module's output to its input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x
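
# Usage sketch (an illustrative addition, not part of the original code): the wrapped
# module must preserve the feature shape, since its output is summed with the input, e.g.
#
#   ffn = Residual(PositionWiseFeedForward(d_model=512, d_ff=2048))
#   y = ffn(torch.randn(2, 16, 512))   # y.shape == (2, 16, 512)
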
| 16 | + |
| 17 | +# https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/feed_forward.py |
| 18 | +class PositionWiseFeedForward(nn.Module): |
| 19 | + """ |
| 20 | + title: Position-wise Feed-Forward Network (FFN) |
| 21 | + summary: Documented reusable implementation of the position wise feedforward network. |
| 22 | +
|
| 23 | + # Position-wise Feed-Forward Network (FFN) |
| 24 | + This is a [PyTorch](https://pytorch.org) implementation |
| 25 | + of position-wise feedforward network used in transformer. |
| 26 | + FFN consists of two fully connected layers. |
| 27 | + Number of dimensions in the hidden layer $d_{ff}$, is generally set to around |
| 28 | + four times that of the token embedding $d_{model}$. |
| 29 | + So it is sometime also called the expand-and-contract network. |
| 30 | + There is an activation at the hidden layer, which is |
| 31 | + usually set to ReLU (Rectified Linear Unit) activation, $$\max(0, x)$$ |
| 32 | + That is, the FFN function is, |
| 33 | + $$FFN(x, W_1, W_2, b_1, b_2) = \max(0, x W_1 + b_1) W_2 + b_2$$ |
| 34 | + where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters. |
| 35 | + Sometimes the |
| 36 | + GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU. |
| 37 | + $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$ |
| 38 | + ### Gated Linear Units |
| 39 | + This is a generic implementation that supports different variants including |
| 40 | + [Gated Linear Units](https://arxiv.org/abs/2002.05202) (GLU). |
| 41 | + We have also implemented experiments on these: |
| 42 | + * [experiment that uses `labml.configs`](glu_variants/experiment.html) |
| 43 | + * [simpler version from scratch](glu_variants/simple.html) |
| 44 | + """ |
| 45 | + |
| 46 | + def __init__(self, d_model: int, d_ff: int, |
| 47 | + dropout: float = 0.1, |
| 48 | + activation=nn.ReLU(), |
| 49 | + is_gated: bool = False, |
| 50 | + bias1: bool = True, |
| 51 | + bias2: bool = True, |
| 52 | + bias_gate: bool = True): |
| 53 | + """ |
| 54 | + * `d_model` is the number of features in a token embedding |
| 55 | + * `d_ff` is the number of features in the hidden layer of the FFN |
| 56 | + * `dropout` is dropout probability for the hidden layer |
| 57 | + * `is_gated` specifies whether the hidden layer is gated |
| 58 | + * `bias1` specified whether the first fully connected layer should have a learnable bias |
| 59 | + * `bias2` specified whether the second fully connected layer should have a learnable bias |
| 60 | + * `bias_gate` specified whether the fully connected layer for the gate should have a learnable bias |
| 61 | + """ |
        super().__init__()
        # Layer one, parameterized by weight $W_1$ and bias $b_1$
        self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
        # Layer two, parameterized by weight $W_2$ and bias $b_2$
        self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
        # Hidden layer dropout
        self.dropout = nn.Dropout(dropout)
        # Activation function $f$
        self.activation = activation
        # Whether there is a gate
        self.is_gated = is_gated
        if is_gated:
            # If there is a gate, the linear layer to transform inputs to
            # be multiplied by the gate, parameterized by weight $V$ and bias $c$
            self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)

    def forward(self, x: torch.Tensor):
        # $f(x W_1 + b_1)$
        g = self.activation(self.layer1(x))
        # If gated, $f(x W_1 + b_1) \otimes (x V + c)$
        if self.is_gated:
            x = g * self.linear_v(x)
        # Otherwise
        else:
            x = g
        # Apply dropout
        x = self.dropout(x)
        # $(f(x W_1 + b_1) \otimes (x V + c)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$,
        # depending on whether it is gated
        return self.layer2(x)
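
# Illustrative addition (not part of the original commit): a purely functional sketch of
# the non-gated path above, assuming the default ReLU activation and ignoring dropout.
# It only exists to make $FFN(x, W_1, W_2, b_1, b_2) = \max(0, x W_1 + b_1) W_2 + b_2$
# concrete; here w1 has shape (d_model, d_ff), i.e. the transpose of nn.Linear.weight.
def _ffn_reference(x: torch.Tensor, w1: torch.Tensor, b1: torch.Tensor,
                   w2: torch.Tensor, b2: torch.Tensor) -> torch.Tensor:
    # max(0, x W1 + b1) W2 + b2
    return F.relu(x @ w1 + b1) @ w2 + b2
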
# GLU Variants Improve Transformer: https://arxiv.org/pdf/2002.05202.pdf
class GEGLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int,
                 dropout: float = 0.1):
        super().__init__()
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout, nn.GELU(),
                                           is_gated=True, bias1=False, bias2=False, bias_gate=False)

    def forward(self, x: torch.Tensor):
        return self.ffn(x)


class ReGLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int,
                 dropout: float = 0.1):
        super().__init__()
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout, nn.ReLU(),
                                           is_gated=True, bias1=False, bias2=False, bias_gate=False)

    def forward(self, x: torch.Tensor):
        return self.ffn(x)


class SwiGLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int,
                 dropout: float = 0.1):
        super().__init__()
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout, nn.SiLU(),
                                           is_gated=True, bias1=False, bias2=False, bias_gate=False)

    def forward(self, x: torch.Tensor):
        return self.ffn(x)
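
# Minimal smoke test (an illustrative addition, not part of the original commit):
# it checks that the plain FFN and the three GLU variants all map a
# (batch, seq_len, d_model) tensor back to the same shape.
if __name__ == "__main__":
    d_model, d_ff = 512, 2048
    x = torch.randn(2, 16, d_model)
    variants = [
        PositionWiseFeedForward(d_model, d_ff),  # plain ReLU FFN
        GEGLU(d_model, d_ff),
        ReGLU(d_model, d_ff),
        SwiGLU(d_model, d_ff),
    ]
    for module in variants:
        out = module(x)
        assert out.shape == x.shape, (type(module).__name__, out.shape)
    print("All feed-forward variants preserve the input shape:", tuple(x.shape))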