@@ -311,20 +311,14 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
311311 return {finder.reads , finder.writes };
312312}
313313
314- struct ScheduleTreeAndDomain {
315- ScheduleTreeUPtr tree;
316- isl::union_set domain;
317- };
318-
319314/*
320- * Helper function for extracting a schedule tree from a Halide Stmt,
315+ * Helper function for extracting a schedule from a Halide Stmt,
321316 * recursively descending over the Stmt.
322317 * "s" is the current position in the recursive descent.
323318 * "set" describes the bounds on the outer loop iterators.
324319 * "outer" contains the names of the outer loop iterators
325320 * from outermost to innermost.
326- * Return the schedule tree corresponding to the subtree at "s",
327- * along with a separated out domain.
321+ * Return the schedule corresponding to the subtree at "s".
328322 *
329323 * "reads" and "writes" collect the accesses found along the way.
330324 * "accesses" collects the mapping from Call (for the reads) and Provide nodes
@@ -334,7 +328,7 @@ struct ScheduleTreeAndDomain {
334328 * "iterators" collects the mapping from instance set tuple identifiers
335329 * to the corresponding outer loop iterator names, from outermost to innermost.
336330 */
337- ScheduleTreeAndDomain makeScheduleTreeHelper (
331+ isl::schedule makeScheduleTreeHelper (
338332 const Stmt& s,
339333 isl::set set,
340334 std::vector<std::string>& outer,
@@ -343,7 +337,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
343337 AccessMap* accesses,
344338 StatementMap* statements,
345339 IteratorMap* iterators) {
346- ScheduleTreeAndDomain result ;
340+ isl::schedule schedule ;
347341 if (auto op = s.as <For>()) {
348342 // Add one additional dimension to our set of loop variables
349343 int thisLoopIdx = set.dim (isl::dim_type::set);
@@ -397,7 +391,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
397391 // dimension. The spaces may be different, but they'll all have
398392 // this loop var at the same index.
399393 isl::multi_union_pw_aff mupa;
400- body.domain .foreach_set ([&](isl::set s) {
394+ body.get_domain () .foreach_set ([&](isl::set s) {
401395 isl::aff loopVar (
402396 isl::local_space (s.get_space ()), isl::dim_type::set, thisLoopIdx);
403397 if (mupa) {
@@ -407,58 +401,20 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
407401 }
408402 });
409403
410- if (body.tree ) {
411- result.tree = ScheduleTree::makeBand (mupa, std::move (body.tree ));
412- } else {
413- result.tree = ScheduleTree::makeBand (mupa);
414- }
415- result.domain = body.domain ;
404+ schedule = body.insert_partial_schedule (mupa);
416405 } else if (auto op = s.as <Halide::Internal::Block>()) {
417- // Flatten a nested block. Halide Block statements always nest
418- // rightwards. Flattening it is not strictly necessary, but it
419- // keeps things uniform with the PET lowering path.
420406 std::vector<Stmt> stmts;
421407 stmts.push_back (op->first );
422408 stmts.push_back (op->rest );
423- while (const Halide::Internal::Block* b =
424- stmts.back ().as <Halide::Internal::Block>()) {
425- Stmt f = b->first ;
426- Stmt r = b->rest ;
427- stmts.pop_back ();
428- stmts.push_back (f);
429- stmts.push_back (r);
430- }
431409
432- // Build a schedule tree for each member of the block, then set up
433- // appropriate filters that state which statements lie in which
434- // children.
435- std::vector<ScheduleTreeUPtr> trees;
410+ // Build a schedule tree for both members of the block and
411+ // combine them in a sequence.
412+ std::vector<isl::schedule> schedules;
436413 for (Stmt s : stmts) {
437- auto mem = makeScheduleTreeHelper (
438- s, set, outer, reads, writes, accesses, statements, iterators);
439- ScheduleTreeUPtr filter;
440- if (mem.tree ) {
441- // No statement instances are shared between the blocks, so we
442- // can drop the constraints on the spaces. This makes the
443- // schedule tree slightly simpler.
444- filter = ScheduleTree::makeFilter (
445- mem.domain .universe (), std::move (mem.tree ));
446- } else {
447- filter = ScheduleTree::makeFilter (mem.domain .universe ());
448- }
449- if (result.domain ) {
450- result.domain = result.domain .unite (mem.domain );
451- } else {
452- result.domain = mem.domain ;
453- }
454- trees.push_back (std::move (filter));
455- }
456- CHECK_GE (trees.size (), 1 );
457-
458- result.tree = ScheduleTree::makeSequence (std::move (trees[0 ]));
459- for (size_t i = 1 ; i < trees.size (); i++) {
460- result.tree ->appendChild (std::move (trees[i]));
414+ schedules.push_back (makeScheduleTreeHelper (
415+ s, set, outer, reads, writes, accesses, statements, iterators));
461416 }
417+ schedule = schedules[0 ].sequence (schedules[1 ]);
462418
463419 } else if (auto op = s.as <Provide>()) {
464420 // Make an ID for this leaf statement. This *is* semantically
@@ -469,7 +425,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
469425 statements->emplace (id, op);
470426 iterators->emplace (id, outer);
471427 isl::set domain = set.set_tuple_id (id);
472- result. domain = domain;
428+ schedule = isl::schedule::from_domain ( domain) ;
473429
474430 isl::union_map newReads, newWrites;
475431 std::tie (newReads, newWrites) =
@@ -481,7 +437,7 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
481437 } else {
482438 LOG (FATAL) << " Unhandled Halide stmt: " << s;
483439 }
484- return result ;
440+ return schedule ;
485441};
486442
487443ScheduleTreeAndAccesses makeScheduleTree (isl::space paramSpace, const Stmt& s) {
@@ -491,7 +447,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
491447
492448 // Walk the IR building a schedule tree
493449 std::vector<std::string> outer;
494- auto treeAndDomain = makeScheduleTreeHelper (
450+ auto schedule = makeScheduleTreeHelper (
495451 s,
496452 isl::set::universe (paramSpace),
497453 outer,
@@ -501,16 +457,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
501457 &result.statements ,
502458 &result.iterators );
503459
504- // TODO: This fails if the stmt is just a Provide node, I'm not sure
505- // what the schedule tree should look like in that case.
506- CHECK (treeAndDomain.tree );
507-
508- // Add the outermost domain node
509- result.tree = ScheduleTree::makeDomain (
510- treeAndDomain.domain , std::move (treeAndDomain.tree ));
511-
512- // Check we have obeyed the ISL invariants
513- checkValidIslSchedule (result.tree .get ());
460+ result.tree = fromIslSchedule (schedule);
514461
515462 return result;
516463}
0 commit comments