@@ -134,6 +134,54 @@ namespace Flux {
         }
     };

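+    // Feed-forward block: Linear ("0") -> GELU -> Linear ("2"). With use_mlp_silu_act the first
+    // projection is 2x wide and ggml_ext_silu_act (presumably a gated/split SiLU) replaces the GELU.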
+    struct MLP : public UnaryBlock {
+        bool use_mlp_silu_act;
+
+    public:
+        MLP(int64_t hidden_size, int64_t intermediate_size, bool use_mlp_silu_act = false, bool bias = false)
+            : use_mlp_silu_act(use_mlp_silu_act) {
+            int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
+            blocks["0"] = std::make_shared<Linear>(hidden_size, intermediate_size * mlp_mult_factor, bias);
+            blocks["2"] = std::make_shared<Linear>(intermediate_size, hidden_size, bias);
+        }
+
+        struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
+            auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["0"]);
+            auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["2"]);
+
+            x = mlp_0->forward(ctx, x);
+            if (use_mlp_silu_act) {
+                x = ggml_ext_silu_act(ctx->ggml_ctx, x);
+            } else {
+                x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            }
+            x = mlp_2->forward(ctx, x);
+            return x;
+        }
+    };
+
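+    // Gated MLP in the SwiGLU style: down_proj(silu(gate_proj(x)) * up_proj(x)); selected via use_yak_mlp (Ovis-Image).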
+    struct YakMLP : public UnaryBlock {
+    public:
+        YakMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = true) {
+            blocks["gate_proj"] = std::make_shared<Linear>(hidden_size, intermediate_size, bias);
+            blocks["up_proj"] = std::make_shared<Linear>(hidden_size, intermediate_size, bias);
+            blocks["down_proj"] = std::make_shared<Linear>(intermediate_size, hidden_size, bias);
+        }
+
+        struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
+            auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
+            auto up_proj = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
+            auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
+
+            auto gate = gate_proj->forward(ctx, x);
+            gate = ggml_silu_inplace(ctx->ggml_ctx, gate);
+            x = up_proj->forward(ctx, x);
+            x = ggml_mul(ctx->ggml_ctx, x, gate);
+            x = down_proj->forward(ctx, x);
+            return x;
+        }
+    };
+
     struct ModulationOut {
         ggml_tensor* shift = nullptr;
         ggml_tensor* scale = nullptr;
@@ -199,7 +247,6 @@ namespace Flux {
     struct DoubleStreamBlock : public GGMLBlock {
         bool prune_mod;
         int idx = 0;
-        bool use_mlp_silu_act;

     public:
         DoubleStreamBlock(int64_t hidden_size,
@@ -210,10 +257,10 @@ namespace Flux {
                           bool prune_mod = false,
                           bool share_modulation = false,
                           bool mlp_proj_bias = true,
+                          bool use_yak_mlp = false,
                           bool use_mlp_silu_act = false)
-            : idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
-            int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
-            int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
+            : idx(idx), prune_mod(prune_mod) {
+            int64_t mlp_hidden_dim = hidden_size * mlp_ratio;

             if (!prune_mod && !share_modulation) {
                 blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
@@ -222,9 +269,11 @@ namespace Flux {
             blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));

             blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
-            blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
-            // img_mlp.1 is nn.GELU(approximate="tanh")
-            blocks["img_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
+            if (use_yak_mlp) {
+                blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new YakMLP(hidden_size, mlp_hidden_dim, mlp_proj_bias));
+            } else {
+                blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new MLP(hidden_size, mlp_hidden_dim, use_mlp_silu_act, mlp_proj_bias));
+            }

             if (!prune_mod && !share_modulation) {
                 blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
@@ -233,9 +282,11 @@ namespace Flux {
             blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));

             blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
-            blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
-            // img_mlp.1 is nn.GELU(approximate="tanh")
-            blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
+            if (use_yak_mlp) {
+                blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new YakMLP(hidden_size, mlp_hidden_dim, mlp_proj_bias));
+            } else {
+                blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new MLP(hidden_size, mlp_hidden_dim, use_mlp_silu_act, mlp_proj_bias));
+            }
         }

         std::vector<ModulationOut> get_distil_img_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
@@ -272,15 +323,13 @@ namespace Flux {
             auto img_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["img_attn"]);

             auto img_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm2"]);
-            auto img_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.0"]);
-            auto img_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.2"]);
+            auto img_mlp = std::dynamic_pointer_cast<UnaryBlock>(blocks["img_mlp"]);

             auto txt_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm1"]);
             auto txt_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["txt_attn"]);

             auto txt_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm2"]);
-            auto txt_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.0"]);
-            auto txt_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.2"]);
+            auto txt_mlp = std::dynamic_pointer_cast<UnaryBlock>(blocks["txt_mlp"]);

             if (img_mods.empty()) {
                 if (prune_mod) {
@@ -348,27 +397,15 @@ namespace Flux {
             // calculate the img bloks
             img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));

-            auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
-            if (use_mlp_silu_act) {
-                img_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, img_mlp_out);
-            } else {
-                img_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, img_mlp_out);
-            }
-            img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);
+            auto img_mlp_out = img_mlp->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));

             img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_mod2.gate));

             // calculate the txt bloks
             txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));

-            auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
-            if (use_mlp_silu_act) {
-                txt_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, txt_mlp_out);
-            } else {
-                txt_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, txt_mlp_out);
-            }
-            txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);
-            txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));
+            auto txt_mlp_out = txt_mlp->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
+            txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));

             return {img, txt};
         }
@@ -381,6 +418,7 @@ namespace Flux {
         int64_t mlp_hidden_dim;
         bool prune_mod;
         int idx = 0;
+        bool use_yak_mlp;
         bool use_mlp_silu_act;
         int64_t mlp_mult_factor;

@@ -393,16 +431,17 @@ namespace Flux {
                           bool prune_mod = false,
                           bool share_modulation = false,
                           bool mlp_proj_bias = true,
+                          bool use_yak_mlp = false,
                           bool use_mlp_silu_act = false)
-            : hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
+            : hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) {
             int64_t head_dim = hidden_size / num_heads;
             float scale = qk_scale;
             if (scale <= 0.f) {
                 scale = 1 / sqrt((float)head_dim);
             }
             mlp_hidden_dim = hidden_size * mlp_ratio;
             mlp_mult_factor = 1;
-            if (use_mlp_silu_act) {
+            if (use_yak_mlp || use_mlp_silu_act) {
                 mlp_mult_factor = 2;
             }

@@ -481,7 +520,9 @@ namespace Flux {
             k = norm->key_norm(ctx, k);
             auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, hidden_size]

-            if (use_mlp_silu_act) {
+            if (use_yak_mlp) {
+                mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
+            } else if (use_mlp_silu_act) {
                 mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
             } else {
                 mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
@@ -726,6 +767,8 @@ namespace Flux {
         int64_t in_dim = 64;
         bool disable_bias = false;
         bool share_modulation = false;
+        bool semantic_txt_norm = false;
+        bool use_yak_mlp = false;
         bool use_mlp_silu_act = false;
         float ref_index_scale = 1.f;
         ChromaRadianceParams chroma_radiance_params;
@@ -759,6 +802,9 @@ namespace Flux {
                     blocks["guidance_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size, !params.disable_bias);
                 }
             }
+            if (params.semantic_txt_norm) {
+                blocks["txt_norm"] = std::make_shared<RMSNorm>(params.context_in_dim);
+            }
             blocks["txt_in"] = std::make_shared<Linear>(params.context_in_dim, params.hidden_size, !params.disable_bias);

             for (int i = 0; i < params.depth; i++) {
@@ -770,6 +816,7 @@ namespace Flux {
                     params.is_chroma,
                     params.share_modulation,
                     !params.disable_bias,
+                    params.use_yak_mlp,
                     params.use_mlp_silu_act);
             }

@@ -782,6 +829,7 @@ namespace Flux {
                     params.is_chroma,
                     params.share_modulation,
                     !params.disable_bias,
+                    params.use_yak_mlp,
                     params.use_mlp_silu_act);
             }

@@ -948,6 +996,12 @@ namespace Flux {
                 ss_mods = single_stream_modulation->forward(ctx, vec);
             }

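+            // When semantic_txt_norm is set (Ovis-Image), RMSNorm the raw text embedding before the txt_in projection.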
+            if (params.semantic_txt_norm) {
+                auto semantic_txt_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["txt_norm"]);
+
+                txt = semantic_txt_norm->forward(ctx, txt);
+            }
+
             txt = txt_in->forward(ctx, txt);

             for (int i = 0; i < params.depth; i++) {
@@ -1206,6 +1260,11 @@ namespace Flux {
             } else if (version == VERSION_CHROMA_RADIANCE) {
                 flux_params.in_channels = 3;
                 flux_params.patch_size = 16;
+            } else if (version == VERSION_OVIS_IMAGE) {
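+                // Ovis-Image: Flux backbone with gated (Yak) MLPs, RMSNorm on the text stream, a 2048-dim
+                // text context, and vec_in_dim = 0 (presumably no pooled-vector conditioning).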
+                flux_params.semantic_txt_norm = true;
+                flux_params.use_yak_mlp = true;
+                flux_params.context_in_dim = 2048;
+                flux_params.vec_in_dim = 0;
             } else if (sd_version_is_flux2(version)) {
                 flux_params.context_in_dim = 15360;
                 flux_params.in_channels = 128;
@@ -1364,13 +1423,22 @@ namespace Flux {
                 ref_latents[i] = to_backend(ref_latents[i]);
             }

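+            // txt_arange_dims presumably selects which rotary axes receive per-token position indices for the
+            // text tokens; Flux.2 keeps its previous behavior of always incrementing the ref-latent index.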
+            std::set<int> txt_arange_dims;
+            if (sd_version_is_flux2(version)) {
+                txt_arange_dims = {3};
+                increase_ref_index = true;
+            } else if (version == VERSION_OVIS_IMAGE) {
+                txt_arange_dims = {1, 2};
+            }
+
             pe_vec = Rope::gen_flux_pe(x->ne[1],
                                        x->ne[0],
                                        flux_params.patch_size,
                                        x->ne[3],
                                        context->ne[1],
+                                       txt_arange_dims,
                                        ref_latents,
-                                       sd_version_is_flux2(version) ? true : increase_ref_index,
+                                       increase_ref_index,
                                        flux_params.ref_index_scale,
                                        flux_params.theta,
                                        flux_params.axes_dim);