1515#include "../analysis/free_variables.h"
1616#include "../analysis/uses.h"
1717#include "../analysis/leak.h"
18+ #include "../analysis/verify.h"
1819
1920#include <assert.h>
2021#include <string.h>
@@ -31,6 +32,8 @@ typedef struct Context_ {
3132 struct Dict * lifted ;
3233 bool disable_lowering ;
3334 const CompilerConfig * config ;
35+
36+ bool * todo ;
3437} Context ;
3538
3639static const Node * process_node (Context * ctx , const Node * node );
@@ -66,6 +69,20 @@ static const Node* add_spill_instrs(Context* ctx, BodyBuilder* builder, struct L
6669 return sp ;
6770}
6871
72+ static void add_to_recover_context (struct List * recover_context , struct Dict * set , const Node * except ) {
73+ Nodes params = get_abstraction_params (except );
74+ size_t i = 0 ;
75+ const Node * item ;
76+ while (dict_iter (set , & i , & item , NULL )) {
77+ for (size_t j = 0 ; j < params .count ; j ++ ) {
78+ if (item == params .nodes [j ])
79+ goto skip ;
80+ }
81+ append_list (const Node * , recover_context , item );
82+ skip :;
83+ }
84+ }
85+
6986static LiftedCont * lambda_lift (Context * ctx , const Node * cont , String given_name ) {
7087 assert (is_basic_block (cont ) || is_case (cont ));
7188 LiftedCont * * found = find_value_dict (const Node * , LiftedCont * , ctx -> lifted , cont );
@@ -82,20 +99,19 @@ static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name
8299 CFNode * cf_node = scope_lookup (ctx -> scope , cont );
83100 CFNodeVariables * node_vars = * find_value_dict (CFNode * , CFNodeVariables * , ctx -> scope_vars , cf_node );
84101 struct List * recover_context = new_list (const Node * );
85- size_t recover_context_size = entries_count_dict (node_vars -> free_set );
86-
87- {
88- debugv_print ("lambda_lift: free (to-be-spilled) variables at '%s' (count=%d): " , name , entries_count_dict (node_vars -> free_set ));
89- size_t i = 0 ;
90- const Node * item ;
91- while (dict_iter (node_vars -> free_set , & i , & item , NULL )) {
92- append_list (const Node * , recover_context , item );
93- debugv_print (get_value_name_safe (item ));
94- if (i + 1 < recover_context_size )
95- debugv_print (", " );
96- }
97- debugv_print ("\n" );
102+
103+ // add_to_recover_context(recover_context, node_vars->free_set, cont);
104+ add_to_recover_context (recover_context , node_vars -> bound_set , cont );
105+ size_t recover_context_size = entries_count_list (recover_context );
106+
107+ debugv_print ("lambda_lift: free (to-be-spilled) variables at '%s' (count=%d): " , name , recover_context_size );
108+ for (size_t i = 0 ; i < recover_context_size ; i ++ ) {
109+ const Node * item = read_list (const Node * , recover_context )[i ];
110+ debugv_print (get_value_name_safe (item ));
111+ if (i + 1 < recover_context_size )
112+ debugv_print (", " );
98113 }
114+ debugv_print ("\n" );
99115
100116 // Create and register new parameters for the lifted continuation
101117 Nodes new_params = recreate_variables (& ctx -> rewriter , oparams );
@@ -106,7 +122,13 @@ static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name
106122 insert_dict (const Node * , LiftedCont * , ctx -> lifted , cont , lifted_cont );
107123
108124 Context lifting_ctx = * ctx ;
109- lifting_ctx .rewriter = create_rewriter (ctx -> rewriter .src_module , ctx -> rewriter .dst_module , (RewriteNodeFn ) process_node );
125+ // struct Dict* old_map = lifting_ctx.rewriter.map;
126+ // lifting_ctx.rewriter.map = clone_dict(lifting_ctx.rewriter.map);
127+
128+ // lifting_ctx.rewriter = create_rewriter(ctx->rewriter.src_module, ctx->rewriter.dst_module, (RewriteNodeFn) process_node);
129+ // lifting_ctx.rewriter.decls_map = NULL;
130+ lifting_ctx .rewriter .map = new_dict (const Node * , Node * , (HashFn ) hash_node , (CmpFn ) compare_node );
131+ lifting_ctx .rewriter .parent = & ctx -> rewriter ;
110132 register_processed_list (& lifting_ctx .rewriter , oparams , new_params );
111133
112134 const Node * payload = var (a , qualified_type_helper (uint32_type (a ), false), "sp" );
@@ -140,6 +162,7 @@ static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name
140162 const Node * substituted = rewrite_node (& lifting_ctx .rewriter , obody );
141163 //destroy_dict(lifting_ctx.rewriter.processed);
142164 destroy_rewriter (& lifting_ctx .rewriter );
165+ // lifting_ctx.rewriter.map = old_map;
143166
144167 assert (is_terminator (substituted ));
145168 new_fn -> payload .fun .body = finish_body (bb , substituted );
@@ -151,27 +174,18 @@ static const Node* process_node(Context* ctx, const Node* node) {
151174 const Node * found = search_processed (& ctx -> rewriter , node );
152175 if (found ) return found ;
153176
154- // TODO: share this code
155- if (is_declaration (node )) {
156- String name = get_declaration_name (node );
157- Nodes decls = get_module_declarations (ctx -> rewriter .dst_module );
158- for (size_t i = 0 ; i < decls .count ; i ++ ) {
159- if (strcmp (get_declaration_name (decls .nodes [i ]), name ) == 0 )
160- return decls .nodes [i ];
161- }
162- }
163-
164177 IrArena * a = ctx -> rewriter .dst_arena ;
165178
166- if (ctx -> disable_lowering )
167- return recreate_node_identity (& ctx -> rewriter , node );
168-
169- switch (node -> tag ) {
179+ switch (is_declaration (node )) {
170180 case Function_TAG : {
181+ while (ctx -> rewriter .parent )
182+ ctx = (Context * ) ctx -> rewriter .parent ;
183+
171184 Context fn_ctx = * ctx ;
172185 fn_ctx .scope = new_scope (node );
173186 fn_ctx .scope_uses = create_uses_map (node , (NcDeclaration | NcType ));
174187 fn_ctx .scope_vars = compute_scope_variables_map (fn_ctx .scope );
188+ fn_ctx .disable_lowering = lookup_annotation (node , "Internal" );
175189 ctx = & fn_ctx ;
176190
177191 Node * new = recreate_decl_header_identity (& ctx -> rewriter , node );
@@ -182,51 +196,88 @@ static const Node* process_node(Context* ctx, const Node* node) {
182196 destroy_scope (ctx -> scope );
183197 return new ;
184198 }
199+ default :
200+ break ;
201+ }
202+
203+ if (ctx -> disable_lowering )
204+ return recreate_node_identity (& ctx -> rewriter , node );
205+
206+ switch (node -> tag ) {
185207 case Let_TAG : {
186208 const Node * oinstruction = get_let_instruction (node );
187209 if (oinstruction -> tag == Control_TAG ) {
188210 const Node * oinside = oinstruction -> payload .control .inside ;
189211 assert (is_case (oinside ));
190212 if (!is_control_static (ctx -> scope_uses , oinstruction ) || ctx -> config -> hacks .force_join_point_lifting ) {
213+ * ctx -> todo = true;
214+
191215 const Node * otail = get_let_tail (node );
192216 BodyBuilder * bb = begin_body (a );
193217 LiftedCont * lifted_tail = lambda_lift (ctx , otail , unique_name (a , format_string_arena (a -> arena , "post_control_%s" , get_abstraction_name (ctx -> scope -> entry -> node ))));
194218 const Node * sp = add_spill_instrs (ctx , bb , lifted_tail -> save_values );
195219 const Node * tail_ptr = fn_addr_helper (a , lifted_tail -> lifted_fn );
196220
197221 const Node * jp = gen_primop_e (bb , create_joint_point_op , rewrite_nodes (& ctx -> rewriter , oinstruction -> payload .control .yield_types ), mk_nodes (a , tail_ptr , sp ));
222+ // dumbass hack
223+ jp = gen_primop_e (bb , subgroup_assume_uniform_op , empty (a ), singleton (jp ));
198224
199225 return finish_body (bb , let (a , quote_helper (a , singleton (jp )), rewrite_node (& ctx -> rewriter , oinside )));
200226 }
201227 }
202-
203- return recreate_node_identity (& ctx -> rewriter , node );
228+ break ;
204229 }
205- default : return recreate_node_identity ( & ctx -> rewriter , node ) ;
230+ default : break ;
206231 }
232+ return recreate_node_identity (& ctx -> rewriter , node );
207233}
208234
209235Module * lift_indirect_targets (const CompilerConfig * config , Module * src ) {
210236 ArenaConfig aconfig = get_arena_config (get_module_arena (src ));
237+ IrArena * a = NULL ;
238+ Module * dst ;
239+
240+ int round = 0 ;
241+ while (true) {
242+ debugv_print ("lift_indirect_target: round %d\n" , round ++ );
243+ IrArena * oa = a ;
244+ a = new_ir_arena (aconfig );
245+ dst = new_module (a , get_module_name (src ));
246+ bool todo = false;
247+ Context ctx = {
248+ .rewriter = create_rewriter (src , dst , (RewriteNodeFn ) process_node ),
249+ .lifted = new_dict (const Node * , LiftedCont * , (HashFn ) hash_node , (CmpFn ) compare_node ),
250+ .config = config ,
251+
252+ .todo = & todo
253+ };
254+
255+ rewrite_module (& ctx .rewriter );
256+
257+ size_t iter = 0 ;
258+ LiftedCont * lifted_cont ;
259+ while (dict_iter (ctx .lifted , & iter , NULL , & lifted_cont )) {
260+ destroy_list (lifted_cont -> save_values );
261+ free (lifted_cont );
262+ }
263+ destroy_dict (ctx .lifted );
264+ destroy_rewriter (& ctx .rewriter );
265+ log_module (DEBUGVV , config , dst );
266+ verify_module (config , dst );
267+ src = dst ;
268+ if (oa )
269+ destroy_ir_arena (oa );
270+ if (!todo ) {
271+ break ;
272+ }
273+ }
274+
211275 // this will be safe now since we won't lift any more code after this pass
212276 aconfig .optimisations .weaken_non_leaking_allocas = true;
213- IrArena * a = new_ir_arena (aconfig );
214- Module * dst = new_module (a , get_module_name (src ));
215- Context ctx = {
216- .rewriter = create_rewriter (src , dst , (RewriteNodeFn ) process_node ),
217- .lifted = new_dict (const Node * , LiftedCont * , (HashFn ) hash_node , (CmpFn ) compare_node ),
218- .config = config ,
219- };
220-
221- rewrite_module (& ctx .rewriter );
222-
223- size_t iter = 0 ;
224- LiftedCont * lifted_cont ;
225- while (dict_iter (ctx .lifted , & iter , NULL , & lifted_cont )) {
226- destroy_list (lifted_cont -> save_values );
227- free (lifted_cont );
228- }
229- destroy_dict (ctx .lifted );
230- destroy_rewriter (& ctx .rewriter );
277+ IrArena * a2 = new_ir_arena (aconfig );
278+ dst = new_module (a2 , get_module_name (src ));
279+ Rewriter r = create_importer (src , dst );
280+ rewrite_module (& r );
281+ destroy_ir_arena (a );
231282 return dst ;
232283}
0 commit comments