@@ -1039,6 +1039,7 @@ def __init__(
         fine_attn_depth = 2,
         fine_attn_dim_head = 32,
         fine_attn_heads = 8,
+        fine_cross_attend_text = False,
         pad_id = -1,
         num_sos_tokens = None,
         condition_on_text = False,
@@ -1137,6 +1138,8 @@ def __init__(
 
         # decoding the vertices, 2-stage hierarchy
 
+        self.fine_cross_attend_text = condition_on_text and fine_cross_attend_text
+
         self.fine_decoder = Decoder(
             dim = dim_fine,
             depth = fine_attn_depth,
@@ -1145,6 +1148,9 @@ def __init__(
             attn_flash = flash_attn,
             attn_dropout = dropout,
             ff_dropout = dropout,
+            cross_attend = self.fine_cross_attend_text,
+            cross_attn_dim_context = cross_attn_dim_context,
+            cross_attn_num_mem_kv = cross_attn_num_mem_kv,
             **attn_kwargs
         )
 
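The new flag only takes effect when text conditioning is on, since it is AND-ed with `condition_on_text`. A minimal usage sketch (assuming the `MeshTransformer` constructor from this repo with a trained autoencoder in scope; the other kwargs are illustrative):

    from meshgpt_pytorch import MeshTransformer

    transformer = MeshTransformer(
        autoencoder,                     # assumed: a trained MeshAutoencoder
        dim = 512,
        condition_on_text = True,        # required - the new flag is AND-ed with this
        fine_cross_attend_text = True    # new: fine decoder also cross attends to text
    )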
@@ -1512,8 +1518,17 @@ def forward_on_codes(
         if exists(fine_cache):
             for attn_intermediate in fine_cache.attn_intermediates:
                 ck, cv = attn_intermediate.cached_kv
-                ck, cv = map(lambda t: rearrange(t, '(b nf) ... -> b nf ...', b = batch), (ck, cv))
-                ck, cv = map(lambda t: t[:, -1, :, :curr_vertex_pos], (ck, cv))
+                ck, cv = [rearrange(t, '(b nf) ... -> b nf ...', b = batch) for t in (ck, cv)]
+
+                # when operating on the cached key / values, treat self attention and cross attention differently
+
+                layer_type = attn_intermediate.layer_type
+
+                if layer_type == 'a':
+                    ck, cv = [t[:, -1, :, :curr_vertex_pos] for t in (ck, cv)]
+                elif layer_type == 'c':
+                    ck, cv = [t[:, -1, ...] for t in (ck, cv)]
+
                 attn_intermediate.cached_kv = (ck, cv)
 
         num_faces = fine_vertex_codes.shape[1]
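The branch is needed because the two cache types differ along the sequence axis: self attention ('a') caches grow with the decoded vertex positions and must be truncated to `curr_vertex_pos`, while cross attention ('c') caches span the fixed-length text context and are kept whole. A toy sketch of the two paths (all shapes illustrative; layer types follow x-transformers' convention):

    import torch
    from einops import rearrange

    batch, num_faces, heads, dim_head = 2, 4, 8, 32
    curr_vertex_pos, text_len = 3, 77

    # self attention cache: keys over decoded fine positions
    ck_self = torch.randn(batch * num_faces, heads, 6, dim_head)
    # cross attention cache: keys over the fixed-length text context
    ck_cross = torch.randn(batch * num_faces, heads, text_len, dim_head)

    ck_self, ck_cross = [rearrange(t, '(b nf) ... -> b nf ...', b = batch) for t in (ck_self, ck_cross)]

    ck_self = ck_self[:, -1, :, :curr_vertex_pos]  # 'a': last face, truncated -> (2, 8, 3, 32)
    ck_cross = ck_cross[:, -1, ...]                # 'c': last face, full text  -> (2, 8, 77, 32)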
@@ -1524,9 +1539,25 @@ def forward_on_codes(
         if one_face:
             fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]
 
+        # handle maybe cross attention conditioning of fine transformer with text
+
+        fine_attn_context_kwargs = dict()
+
+        if self.fine_cross_attend_text:
+            repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]
+
+            text_embed = repeat(text_embed, 'b ... -> (b r) ...', r = repeat_batch)
+            text_mask = repeat(text_mask, 'b ... -> (b r) ...', r = repeat_batch)
+
+            fine_attn_context_kwargs = dict(
+                context = text_embed,
+                context_mask = text_mask
+            )
+
         attended_vertex_codes, fine_cache = self.fine_decoder(
             fine_vertex_codes,
             cache = fine_cache,
+            **fine_attn_context_kwargs,
             return_hiddens = True
         )
 
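Since `fine_vertex_codes` has faces folded into its batch dimension, the per-mesh text embedding and mask must be repeated to match before being handed to the fine decoder as context. A small alignment sketch (names and shapes illustrative):

    import torch
    from einops import repeat

    batch, num_faces, text_len, dim_context = 2, 5, 77, 768

    fine_vertex_codes = torch.randn(batch * num_faces, 9, 512)  # (b nf) folded together
    text_embed = torch.randn(batch, text_len, dim_context)
    text_mask = torch.ones(batch, text_len, dtype = torch.bool)

    repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]  # == num_faces

    text_embed = repeat(text_embed, 'b ... -> (b r) ...', r = repeat_batch)
    text_mask = repeat(text_mask, 'b ... -> (b r) ...', r = repeat_batch)

    assert text_embed.shape[0] == fine_vertex_codes.shape[0]

Note that `(b r)` places the r copies for each mesh consecutively, matching the `(b nf)` layout of the folded face codes.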