@@ -101,17 +101,20 @@ static const Node* process_let(Context* ctx, const Node* node) {
101101 const PrimOp * oprim_op = & old_instruction -> payload .prim_op ;
102102 switch (oprim_op -> op ) {
103103 case get_stack_pointer_op : {
104+ assert (ctx -> stack );
104105 BodyBuilder * bb = begin_body (a );
105106 const Node * sp = gen_load (bb , ctx -> stack_pointer );
106107 return finish_body (bb , let (a , quote_helper (a , singleton (sp )), tail ));
107108 }
108109 case set_stack_pointer_op : {
110+ assert (ctx -> stack );
109111 BodyBuilder * bb = begin_body (a );
110112 const Node * val = rewrite_node (& ctx -> rewriter , oprim_op -> operands .nodes [0 ]);
111113 gen_store (bb , ctx -> stack_pointer , val );
112114 return finish_body (bb , let (a , quote_helper (a , empty (a )), tail ));
113115 }
114116 case get_stack_base_op : {
117+ assert (ctx -> stack );
115118 BodyBuilder * bb = begin_body (a );
116119 const Node * stack_pointer = ctx -> stack_pointer ;
117120 const Node * stack_size = gen_load (bb , stack_pointer );
@@ -126,6 +129,7 @@ static const Node* process_let(Context* ctx, const Node* node) {
126129 }
127130 case push_stack_op :
128131 case pop_stack_op : {
132+ assert (ctx -> stack );
129133 BodyBuilder * bb = begin_body (a );
130134 const Type * element_type = rewrite_node (& ctx -> rewriter , first (oprim_op -> type_arguments ));
131135
@@ -161,8 +165,10 @@ static const Node* process_node(Context* ctx, const Node* old) {
161165 // Make sure to zero-init the stack pointers
162166 // TODO isn't this redundant with thoose things having an initial value already ?
163167 // is this an old forgotten workaround ?
164- const Node * stack_pointer = ctx -> stack_pointer ;
165- gen_store (bb , stack_pointer , uint32_literal (a , 0 ));
168+ if (ctx -> stack ) {
169+ const Node * stack_pointer = ctx -> stack_pointer ;
170+ gen_store (bb , stack_pointer , uint32_literal (a , 0 ));
171+ }
166172 new -> payload .fun .body = finish_body (bb , rewrite_node (& ctx -> rewriter , old -> payload .fun .body ));
167173 return new ;
168174 }
@@ -181,34 +187,36 @@ Module* lower_stack(SHADY_UNUSED const CompilerConfig* config, Module* src) {
181187 IrArena * a = new_ir_arena (aconfig );
182188 Module * dst = new_module (a , get_module_name (src ));
183189
184- const Type * stack_base_element = uint8_type (a );
185- const Type * stack_arr_type = arr_type (a , (ArrType ) {
186- .element_type = stack_base_element ,
187- .size = uint32_literal (a , config -> per_thread_stack_size ),
188- });
189- const Type * stack_counter_t = uint32_type (a );
190-
191- Nodes annotations = mk_nodes (a , annotation (a , (Annotation ) { .name = "Generated" }));
192-
193- // Arrays for the stacks
194- Node * stack_decl = global_var (dst , annotations , stack_arr_type , "stack" , AsPrivate );
195-
196- // Pointers into those arrays
197- Node * stack_ptr_decl = global_var (dst , append_nodes (a , annotations , annotation (a , (Annotation ) { .name = "Logical" })), stack_counter_t , "stack_ptr" , AsPrivate );
198- stack_ptr_decl -> payload .global_variable .init = uint32_literal (a , 0 );
199-
200190 Context ctx = {
201191 .rewriter = create_rewriter (src , dst , (RewriteNodeFn ) process_node ),
202192
203193 .config = config ,
204194
205195 .push = new_dict (const Node * , Node * , (HashFn ) hash_node , (CmpFn ) compare_node ),
206196 .pop = new_dict (const Node * , Node * , (HashFn ) hash_node , (CmpFn ) compare_node ),
207-
208- .stack = ref_decl_helper (a , stack_decl ),
209- .stack_pointer = ref_decl_helper (a , stack_ptr_decl ),
210197 };
211198
199+ if (config -> per_thread_stack_size > 0 ) {
200+ const Type * stack_base_element = uint8_type (a );
201+ const Type * stack_arr_type = arr_type (a , (ArrType ) {
202+ .element_type = stack_base_element ,
203+ .size = uint32_literal (a , config -> per_thread_stack_size ),
204+ });
205+ const Type * stack_counter_t = uint32_type (a );
206+
207+ Nodes annotations = mk_nodes (a , annotation (a , (Annotation ) { .name = "Generated" }));
208+
209+ // Arrays for the stacks
210+ Node * stack_decl = global_var (dst , annotations , stack_arr_type , "stack" , AsPrivate );
211+
212+ // Pointers into those arrays
213+ Node * stack_ptr_decl = global_var (dst , append_nodes (a , annotations , annotation (a , (Annotation ) { .name = "Logical" })), stack_counter_t , "stack_ptr" , AsPrivate );
214+ stack_ptr_decl -> payload .global_variable .init = uint32_literal (a , 0 );
215+
216+ ctx .stack = ref_decl_helper (a , stack_decl );
217+ ctx .stack_pointer = ref_decl_helper (a , stack_ptr_decl );
218+ }
219+
212220 rewrite_module (& ctx .rewriter );
213221 destroy_rewriter (& ctx .rewriter );
214222
0 commit comments