
Commit 014e619

committed
--added components
1 parent 1c8bdab commit 014e619

1 file changed: +119 −0

@@ -0,0 +1,119 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange

# helpers

def default(val, d):
    # small helper (not in the original hunk) that MLP.__init__ below relies on
    return val if val is not None else d

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# attention

class GEGLU(nn.Module):
    def forward(self, x):
        # gated GELU: split the last dimension in half, gate one half with the other
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),  # doubled width so GEGLU can split into value and gate
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, **kwargs):
        return self.net(x)

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 16,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

# transformer

class Transformer(nn.Module):
    def __init__(self, num_tokens, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
        super().__init__()
        self.embeds = nn.Embedding(num_tokens, dim)
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout))),
                Residual(PreNorm(dim, FeedForward(dim, dropout = ff_dropout))),
            ]))

    def forward(self, x):
        x = self.embeds(x)

        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)

        return x

# mlp

class MLP(nn.Module):
    def __init__(self, dims, act = None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            # only skip the activation after the final linear layer
            # (the original hunk compared against len(dims) - 1, which never triggers)
            is_last = ind >= (len(dims_pairs) - 1)
            linear = nn.Linear(dim_in, dim_out)
            layers.append(linear)

            if is_last:
                continue

            act = default(act, nn.ReLU())
            layers.append(act)

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
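
A minimal usage sketch of the added components (not part of the commit; all argument values below are illustrative assumptions chosen only to show how the pieces fit together):

# assumed shapes: a batch of 5 categorical token ids, embedded to dim 32
import torch

model = Transformer(
    num_tokens = 10,     # assumed vocabulary size for the embedding table
    dim = 32,
    depth = 6,
    heads = 8,
    dim_head = 16,
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

x = torch.randint(0, 10, (1, 5))   # (batch, seq) of token ids
contextual = model(x)              # -> (1, 5, 32)

mlp = MLP([32, 64, 1])             # Linear(32, 64), ReLU, Linear(64, 1)
out = mlp(contextual)              # -> (1, 5, 1)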
