@@ -3108,10 +3108,10 @@ class ArrayExprLowering {
31083108 void lowerArrayAssignment (const TL &lhs, const TR &rhs) {
31093109 auto loc = getLoc ();
31103110 // / Here the target subspace is not necessarily contiguous. The ArrayUpdate
3111- // / continuation is implicitly returned in `ccDest ` and the ArrayLoad in
3112- // / `destination`.
3111+ // / continuation is implicitly returned in `ccStoreToDest ` and the ArrayLoad
3112+ // / in `destination`.
31133113 PushSemantics (ConstituentSemantics::ProjectedCopyInCopyOut);
3114- ccDest = genarr (lhs);
3114+ ccStoreToDest = genarr (lhs);
31153115 determineShapeOfDest (lhs);
31163116 semant = ConstituentSemantics::RefTransparent;
31173117 auto exv = lowerArrayExpression (rhs);
@@ -3143,7 +3143,7 @@ class ArrayExprLowering {
31433143 newIters.prependIndexValue (i);
31443144 return newIters;
31453145 };
3146- ccDest = [=](IterSpace iters) { return lambda (pc (iters)); };
3146+ ccStoreToDest = [=](IterSpace iters) { return lambda (pc (iters)); };
31473147 destShape.assign (extents.begin (), extents.end ());
31483148 semant = ConstituentSemantics::RefTransparent;
31493149 auto exv = lowerArrayExpression (rhs);
@@ -3246,7 +3246,8 @@ class ArrayExprLowering {
32463246 destShape, lengthParams);
32473247 // Create ArrayLoad for the mutable box and save it into `destination`.
32483248 PushSemantics (ConstituentSemantics::ProjectedCopyInCopyOut);
3249- ccDest = genarr (fir::factory::genMutableBoxRead (builder, loc, mutableBox));
3249+ ccStoreToDest =
3250+ genarr (fir::factory::genMutableBoxRead (builder, loc, mutableBox));
32503251 // If the rhs is scalar, get shape from the allocatable ArrayLoad.
32513252 if (destShape.empty ())
32523253 destShape = getShape (destination);
@@ -3310,7 +3311,7 @@ class ArrayExprLowering {
33103311
33113312 // / Entry point into lowering an expression with rank. This entry point is for
33123313 // / lowering a rhs expression, for example. (RefTransparent semantics.)
3313- static ExtValue lowerSomeNewArrayExpression (
3314+ static ExtValue lowerNewArrayExpression (
33143315 Fortran::lower::AbstractConverter &converter,
33153316 Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
33163317 const std::optional<Fortran::evaluate::Shape> &shape,
@@ -3331,7 +3332,7 @@ class ArrayExprLowering {
33313332 fir::dyn_cast_ptrEleTy (tempRes.getType ()).cast <fir::SequenceType>();
33323333 if (auto charTy =
33333334 arrTy.getEleTy ().template dyn_cast <fir::CharacterType>()) {
3334- if (charTy. getLen () <= 0 )
3335+ if (fir::characterWithDynamicLen (charTy) )
33353336 TODO (loc, " CHARACTER does not have constant LEN" );
33363337 auto len = builder.createIntegerConstant (
33373338 loc, builder.getCharacterLengthType (), charTy.getLen ());
@@ -3340,6 +3341,99 @@ class ArrayExprLowering {
33403341 return fir::ArrayBoxValue (tempRes, dest.getExtents ());
33413342 }
33423343
3344+ static ExtValue lowerLazyArrayExpression (
3345+ Fortran::lower::AbstractConverter &converter,
3346+ Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3347+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3348+ mlir::Value var) {
3349+ ArrayExprLowering ael{converter, stmtCtx, symMap};
3350+ return ael.lowerLazyArrayExpression (expr, var);
3351+ }
3352+
3353+ ExtValue lowerLazyArrayExpression (
3354+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3355+ mlir::Value var) {
3356+ auto loc = getLoc ();
3357+ // Once the loop extents have been computed, which may require being inside
3358+ // some explicit loops, lazily allocate the expression on the heap.
3359+ ccPrelude = [=](llvm::ArrayRef<mlir::Value> shape) -> mlir::Value {
3360+ auto load = builder.create <fir::LoadOp>(loc, var);
3361+ auto eleTy = fir::unwrapRefType (load.getType ());
3362+ auto unknown = fir::SequenceType::getUnknownExtent ();
3363+ fir::SequenceType::Shape extents (shape.size (), unknown);
3364+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3365+ auto toTy = fir::HeapType::get (seqTy);
3366+ auto castTo = builder.createConvert (loc, toTy, load);
3367+ auto cmp = builder.genIsNull (loc, castTo);
3368+ auto ifOp = builder.create <fir::IfOp>(loc, cmp, /* withElseRegion=*/ false );
3369+ auto insPt = builder.saveInsertionPoint ();
3370+ builder.setInsertionPointToStart (&ifOp.thenRegion ().front ());
3371+ auto mem = builder.create <fir::AllocMemOp>(loc, seqTy, " .lazy.mask" ,
3372+ llvm::None, shape);
3373+ auto uncast = builder.createConvert (loc, load.getType (), mem);
3374+ builder.create <fir::StoreOp>(loc, uncast, var);
3375+ builder.restoreInsertionPoint (insPt);
3376+ return mem;
3377+ };
3378+ // Create a dummy array_load before the loop. We're storing to a lazy
3379+ // temporary, so there will be no conflict and no copy-in.
3380+ ccLoadDest = [=](llvm::ArrayRef<mlir::Value> shape) -> fir::ArrayLoadOp {
3381+ auto load = builder.create <fir::LoadOp>(loc, var);
3382+ auto eleTy = fir::unwrapRefType (load.getType ());
3383+ auto unknown = fir::SequenceType::getUnknownExtent ();
3384+ fir::SequenceType::Shape extents (shape.size (), unknown);
3385+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3386+ auto toTy = fir::HeapType::get (seqTy);
3387+ auto castTo = builder.createConvert (loc, toTy, load);
3388+ auto shapeOp = builder.consShape (loc, shape);
3389+ return builder.create <fir::ArrayLoadOp>(
3390+ loc, seqTy, castTo, shapeOp, /* slice=*/ mlir::Value{}, llvm::None);
3391+ };
3392+ // Custom lowering of the element store to deal with the extra indirection
3393+ // to the lazy allocated buffer.
3394+ ccStoreToDest = [=](IterSpace iters) {
3395+ auto load = builder.create <fir::LoadOp>(loc, var);
3396+ auto eleTy = fir::unwrapRefType (load.getType ());
3397+ auto unknown = fir::SequenceType::getUnknownExtent ();
3398+ fir::SequenceType::Shape extents (iters.iterVec ().size (), unknown);
3399+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3400+ auto toTy = fir::HeapType::get (seqTy);
3401+ auto castTo = builder.createConvert (loc, toTy, load);
3402+ auto shape = builder.consShape (loc, genIterationShape ());
3403+ auto indices = fir::factory::originateIndices (
3404+ loc, builder, castTo.getType (), shape, iters.iterVec ());
3405+ auto eleAddr = builder.create <fir::ArrayCoorOp>(
3406+ loc, builder.getRefType (eleTy), castTo, shape,
3407+ /* slice=*/ mlir::Value{}, indices, destination.typeparams ());
3408+ auto eleVal = builder.createConvert (loc, eleTy, iters.getElement ());
3409+ builder.create <fir::StoreOp>(loc, eleVal, eleAddr);
3410+ return iters.innerArgument ();
3411+ };
3412+ auto loopRes = lowerArrayExpression (expr);
3413+ auto load = builder.create <fir::LoadOp>(loc, var);
3414+ auto eleTy = fir::unwrapRefType (load.getType ());
3415+ auto unknown = fir::SequenceType::getUnknownExtent ();
3416+ fir::SequenceType::Shape extents (genIterationShape ().size (), unknown);
3417+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3418+ auto toTy = fir::HeapType::get (seqTy);
3419+ auto tempRes = builder.createConvert (loc, toTy, load);
3420+ builder.create <fir::ArrayMergeStoreOp>(
3421+ loc, destination, fir::getBase (loopRes), tempRes, destination.slice (),
3422+ destination.typeparams ());
3423+ auto tempTy = fir::dyn_cast_ptrEleTy (tempRes.getType ());
3424+ assert (tempTy && tempTy.isa <fir::SequenceType>() &&
3425+ " must be a reference to an array" );
3426+ auto ety = fir::unwrapSequenceType (tempTy);
3427+ if (auto charTy = ety.dyn_cast <fir::CharacterType>()) {
3428+ if (fir::characterWithDynamicLen (charTy))
3429+ TODO (loc, " CHARACTER does not have constant LEN" );
3430+ auto len = builder.createIntegerConstant (
3431+ loc, builder.getCharacterLengthType (), charTy.getLen ());
3432+ return fir::CharArrayBoxValue (tempRes, len, destination.getExtents ());
3433+ }
3434+ return fir::ArrayBoxValue (tempRes, destination.getExtents ());
3435+ }
3436+
33433437 void determineShapeOfDest (const fir::ExtendedValue &lhs) {
33443438 destShape = fir::factory::getExtents (builder, getLoc (), lhs);
33453439 }
@@ -3416,9 +3510,9 @@ class ArrayExprLowering {
34163510 auto innerArg = iterSpace.innerArgument ();
34173511 auto exv = f (iterSpace);
34183512 mlir::Value upd;
3419- if (ccDest .hasValue ()) {
3513+ if (ccStoreToDest .hasValue ()) {
34203514 iterSpace.setElement (std::move (exv));
3421- upd = fir::getBase (ccDest .getValue ()(iterSpace));
3515+ upd = fir::getBase (ccStoreToDest .getValue ()(iterSpace));
34223516 } else {
34233517 auto resTy = adjustedArrayElementType (innerArg.getType ());
34243518 auto element = adjustedArrayElement (loc, builder, fir::getBase (exv),
@@ -3509,6 +3603,14 @@ class ArrayExprLowering {
35093603 // Mask expressions are array expressions too.
35103604 for (const auto *e : implicitSpace->getExprs ())
35113605 if (e && !implicitSpace->isLowered (e)) {
3606+ if (auto var = implicitSpace->lookupVariable (e)) {
3607+ // Allocate the mask buffer lazily.
3608+ auto tmp = Fortran::lower::createLazyArrayTempValue (
3609+ converter, *e, var, symMap, stmtCtx);
3610+ auto shape = builder.createShape (loc, tmp);
3611+ implicitSpace->bind (e, fir::getBase (tmp), shape);
3612+ continue ;
3613+ }
35123614 auto optShape =
35133615 Fortran::evaluate::GetShape (converter.getFoldingContext (), *e);
35143616 auto tmp = Fortran::lower::createSomeArrayTempValue (
@@ -3557,6 +3659,12 @@ class ArrayExprLowering {
35573659 const auto loopDepth = loopUppers.size ();
35583660 llvm::SmallVector<mlir::Value> ivars;
35593661 if (loopDepth > 0 ) {
3662+ // Generate the lazy mask allocation, if one was given.
3663+ if (ccPrelude.hasValue ()) {
3664+ [[maybe_unused]] auto allocMem = ccPrelude.getValue ()(shape);
3665+ assert (allocMem && " mask buffer allocation failure" );
3666+ }
3667+
35603668 auto *startBlock = builder.getBlock ();
35613669 for (auto i : llvm::enumerate (llvm::reverse (loopUppers))) {
35623670 if (i.index () > 0 ) {
@@ -3600,7 +3708,7 @@ class ArrayExprLowering {
36003708 // explicit masks, which are interleaved, these mask expression appear in
36013709 // the innermost loop.
36023710 if (implicitSpaceHasMasks ()) {
3603- auto prependAsNeeded = [&](auto &&indices) {
3711+ auto appendAsNeeded = [&](auto &&indices) {
36043712 llvm::SmallVector<mlir::Value> result;
36053713 result.append (indices.begin (), indices.end ());
36063714 return result;
@@ -3614,7 +3722,7 @@ class ArrayExprLowering {
36143722 auto eleRefTy = builder.getRefType (eleTy);
36153723 auto i1Ty = builder.getI1Type ();
36163724 // Adjust indices for any shift of the origin of the array.
3617- auto indexes = prependAsNeeded (fir::factory::originateIndices (
3725+ auto indexes = appendAsNeeded (fir::factory::originateIndices (
36183726 loc, builder, tmp.getType (), shape, iters.iterVec ()));
36193727 auto addr = builder.create <fir::ArrayCoorOp>(
36203728 loc, eleRefTy, tmp, shape, /* slice=*/ mlir::Value{}, indexes,
@@ -3664,6 +3772,8 @@ class ArrayExprLowering {
36643772 fir::ArrayLoadOp
36653773 createAndLoadSomeArrayTemp (mlir::Type type,
36663774 llvm::ArrayRef<mlir::Value> shape) {
3775+ if (ccLoadDest.hasValue ())
3776+ return ccLoadDest.getValue ()(shape);
36673777 auto seqTy = type.dyn_cast <fir::SequenceType>();
36683778 assert (seqTy && " must be an array" );
36693779 auto loc = getLoc ();
@@ -4613,7 +4723,7 @@ class ArrayExprLowering {
46134723 auto loc = getLoc ();
46144724 auto memref = fir::getBase (extMemref);
46154725 auto arrTy = fir::dyn_cast_ptrOrBoxEleTy (memref.getType ());
4616- assert (arrTy.isa <fir::SequenceType>());
4726+ assert (arrTy.isa <fir::SequenceType>() && " memory ref must be an array " );
46174727 auto shape = builder.createShape (loc, extMemref);
46184728 mlir::Value slice;
46194729 if (inSlice) {
@@ -4898,7 +5008,7 @@ class ArrayExprLowering {
48985008 if (isArray (x)) {
48995009 auto e = toEvExpr (x);
49005010 auto sh = Fortran::evaluate::GetShape (converter.getFoldingContext (), e);
4901- return {lowerSomeNewArrayExpression (converter, symMap, stmtCtx, sh, e),
5011+ return {lowerNewArrayExpression (converter, symMap, stmtCtx, sh, e),
49025012 /* needCopy=*/ true };
49035013 }
49045014 return {asScalar (x), /* needCopy=*/ true };
@@ -5429,7 +5539,11 @@ class ArrayExprLowering {
54295539 Fortran::lower::StatementContext &stmtCtx;
54305540 Fortran::lower::SymMap &symMap;
54315541 // / The continuation to generate code to update the destination.
5432- llvm::Optional<CC> ccDest;
5542+ llvm::Optional<CC> ccStoreToDest;
5543+ llvm::Optional<std::function<mlir::Value(llvm::ArrayRef<mlir::Value>)>>
5544+ ccPrelude;
5545+ llvm::Optional<std::function<fir::ArrayLoadOp(llvm::ArrayRef<mlir::Value>)>>
5546+ ccLoadDest;
54335547 // / The destination is the loaded array into which the results will be
54345548 // / merged.
54355549 fir::ArrayLoadOp destination;
@@ -5539,8 +5653,18 @@ fir::ExtendedValue Fortran::lower::createSomeArrayTempValue(
55395653 const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
55405654 Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx) {
55415655 LLVM_DEBUG (expr.AsFortran (llvm::dbgs () << " array value: " ) << ' \n ' );
5542- return ArrayExprLowering::lowerSomeNewArrayExpression (converter, symMap,
5543- stmtCtx, shape, expr);
5656+ return ArrayExprLowering::lowerNewArrayExpression (converter, symMap, stmtCtx,
5657+ shape, expr);
5658+ }
5659+
5660+ fir::ExtendedValue Fortran::lower::createLazyArrayTempValue (
5661+ Fortran::lower::AbstractConverter &converter,
5662+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
5663+ mlir::Value var, Fortran::lower::SymMap &symMap,
5664+ Fortran::lower::StatementContext &stmtCtx) {
5665+ LLVM_DEBUG (expr.AsFortran (llvm::dbgs () << " array value: " ) << ' \n ' );
5666+ return ArrayExprLowering::lowerLazyArrayExpression (converter, symMap, stmtCtx,
5667+ expr, var);
55445668}
55455669
55465670fir::ExtendedValue Fortran::lower::createSomeArrayBox (
@@ -5637,6 +5761,9 @@ void Fortran::lower::createArrayMergeStores(
56375761 builder.create <fir::ArrayMergeStoreOp>(
56385762 loc, load, i.value (), load.memref (), load.slice (), load.typeparams ());
56395763 }
5764+ // Cleanup any residual mask buffers.
5765+ esp.outermostContext ().finalize ();
5766+ esp.outermostContext ().reset ();
56405767 }
56415768 esp.outerLoopStack .pop_back ();
56425769 esp.innerArgsStack .pop_back ();
0 commit comments