import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange

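# helpers

def default(val, d):
    # added so the MLP class below is runnable: fall back to d when val is None
    return val if val is not None else d
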
class Residual(nn.Module):
    # wraps a module and adds its input back to its output (skip connection)
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    # applies LayerNorm to the input before the wrapped module
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# attention

class GEGLU(nn.Module):
    # gated GELU: split the last dimension in half and gate one half by GELU of the other
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        # first linear projects to twice the hidden width because GEGLU halves it again
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, **kwargs):
        return self.net(x)

class Attention(nn.Module):
    # standard multi-head self-attention
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 16,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.heads
        # project to queries, keys and values, then split out the heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scaled dot-product similarity between all pairs of positions
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # attention-weighted sum of values, then merge the heads back together
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

# transformer

class Transformer(nn.Module):
    def __init__(self, num_tokens, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
        super().__init__()
        self.embeds = nn.Embedding(num_tokens, dim)
        self.layers = nn.ModuleList([])

        # each layer is a pre-norm attention block followed by a pre-norm feedforward block,
        # both wrapped in residual connections
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout))),
                Residual(PreNorm(dim, FeedForward(dim, dropout = ff_dropout))),
            ]))

    def forward(self, x):
        x = self.embeds(x)

        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)

        return x

# mlp

class MLP(nn.Module):
    def __init__(self, dims, act = None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            # skip the activation after the final linear layer
            is_last = ind >= (len(dims_pairs) - 1)
            linear = nn.Linear(dim_in, dim_out)
            layers.append(linear)

            if is_last:
                continue

            act = default(act, nn.ReLU())
            layers.append(act)

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

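# minimal smoke test (illustrative only: the shapes and hyperparameters below are arbitrary choices, not values from the original file)

if __name__ == '__main__':
    model = Transformer(num_tokens = 10, dim = 32, depth = 2, heads = 4, dim_head = 16, attn_dropout = 0.1, ff_dropout = 0.1)
    tokens = torch.randint(0, 10, (1, 5))     # (batch, seq) of categorical token ids
    contextual = model(tokens)                # -> (1, 5, 32) contextual embeddings

    mlp = MLP([32, 16, 1])                    # Linear(32, 16) -> ReLU -> Linear(16, 1)
    out = mlp(contextual.mean(dim = 1))       # pool over the sequence, then project to a scalar
    print(contextual.shape, out.shape)        # torch.Size([1, 5, 32]) torch.Size([1, 1])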