@@ -3108,10 +3108,10 @@ class ArrayExprLowering {
31083108 void lowerArrayAssignment (const TL &lhs, const TR &rhs) {
31093109 auto loc = getLoc ();
31103110 // / Here the target subspace is not necessarily contiguous. The ArrayUpdate
3111- // / continuation is implicitly returned in `ccDest ` and the ArrayLoad in
3112- // / `destination`.
3111+ // / continuation is implicitly returned in `ccStoreToDest ` and the ArrayLoad
3112+ // / in `destination`.
31133113 PushSemantics (ConstituentSemantics::ProjectedCopyInCopyOut);
3114- ccDest = genarr (lhs);
3114+ ccStoreToDest = genarr (lhs);
31153115 determineShapeOfDest (lhs);
31163116 semant = ConstituentSemantics::RefTransparent;
31173117 auto exv = lowerArrayExpression (rhs);
@@ -3143,7 +3143,7 @@ class ArrayExprLowering {
31433143 newIters.prependIndexValue (i);
31443144 return newIters;
31453145 };
3146- ccDest = [=](IterSpace iters) { return lambda (pc (iters)); };
3146+ ccStoreToDest = [=](IterSpace iters) { return lambda (pc (iters)); };
31473147 destShape.assign (extents.begin (), extents.end ());
31483148 semant = ConstituentSemantics::RefTransparent;
31493149 auto exv = lowerArrayExpression (rhs);
@@ -3246,7 +3246,8 @@ class ArrayExprLowering {
32463246 destShape, lengthParams);
32473247 // Create ArrayLoad for the mutable box and save it into `destination`.
32483248 PushSemantics (ConstituentSemantics::ProjectedCopyInCopyOut);
3249- ccDest = genarr (fir::factory::genMutableBoxRead (builder, loc, mutableBox));
3249+ ccStoreToDest =
3250+ genarr (fir::factory::genMutableBoxRead (builder, loc, mutableBox));
32503251 // If the rhs is scalar, get shape from the allocatable ArrayLoad.
32513252 if (destShape.empty ())
32523253 destShape = getShape (destination);
@@ -3310,7 +3311,7 @@ class ArrayExprLowering {
33103311
33113312 // / Entry point into lowering an expression with rank. This entry point is for
33123313 // / lowering a rhs expression, for example. (RefTransparent semantics.)
3313- static ExtValue lowerSomeNewArrayExpression (
3314+ static ExtValue lowerNewArrayExpression (
33143315 Fortran::lower::AbstractConverter &converter,
33153316 Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
33163317 const std::optional<Fortran::evaluate::Shape> &shape,
@@ -3331,7 +3332,7 @@ class ArrayExprLowering {
33313332 fir::dyn_cast_ptrEleTy (tempRes.getType ()).cast <fir::SequenceType>();
33323333 if (auto charTy =
33333334 arrTy.getEleTy ().template dyn_cast <fir::CharacterType>()) {
3334- if (charTy. getLen () <= 0 )
3335+ if (fir::characterWithDynamicLen (charTy) )
33353336 TODO (loc, " CHARACTER does not have constant LEN" );
33363337 auto len = builder.createIntegerConstant (
33373338 loc, builder.getCharacterLengthType (), charTy.getLen ());
@@ -3340,6 +3341,98 @@ class ArrayExprLowering {
33403341 return fir::ArrayBoxValue (tempRes, dest.getExtents ());
33413342 }
33423343
3344+ static ExtValue lowerLazyArrayExpression (
3345+ Fortran::lower::AbstractConverter &converter,
3346+ Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3347+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3348+ mlir::Value var) {
3349+ ArrayExprLowering ael{converter, stmtCtx, symMap};
3350+ return ael.lowerLazyArrayExpression (expr, var);
3351+ }
3352+
3353+ ExtValue lowerLazyArrayExpression (
3354+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3355+ mlir::Value var) {
3356+ auto loc = getLoc ();
3357+ // Once the loop extents have been computed, which may require being inside
3358+ // some explicit loops, lazily allocate the expression on the heap.
3359+ ccPrelude = [=](llvm::ArrayRef<mlir::Value> shape) {
3360+ auto load = builder.create <fir::LoadOp>(loc, var);
3361+ auto eleTy = fir::unwrapRefType (load.getType ());
3362+ auto unknown = fir::SequenceType::getUnknownExtent ();
3363+ fir::SequenceType::Shape extents (shape.size (), unknown);
3364+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3365+ auto toTy = fir::HeapType::get (seqTy);
3366+ auto castTo = builder.createConvert (loc, toTy, load);
3367+ auto cmp = builder.genIsNull (loc, castTo);
3368+ builder.genIfThen (loc, cmp)
3369+ .genThen ([&]() {
3370+ auto mem = builder.create <fir::AllocMemOp>(loc, seqTy, " .lazy.mask" ,
3371+ llvm::None, shape);
3372+ auto uncast = builder.createConvert (loc, load.getType (), mem);
3373+ builder.create <fir::StoreOp>(loc, uncast, var);
3374+ })
3375+ .end ();
3376+ };
3377+ // Create a dummy array_load before the loop. We're storing to a lazy
3378+ // temporary, so there will be no conflict and no copy-in.
3379+ ccLoadDest = [=](llvm::ArrayRef<mlir::Value> shape) -> fir::ArrayLoadOp {
3380+ auto load = builder.create <fir::LoadOp>(loc, var);
3381+ auto eleTy = fir::unwrapRefType (load.getType ());
3382+ auto unknown = fir::SequenceType::getUnknownExtent ();
3383+ fir::SequenceType::Shape extents (shape.size (), unknown);
3384+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3385+ auto toTy = fir::HeapType::get (seqTy);
3386+ auto castTo = builder.createConvert (loc, toTy, load);
3387+ auto shapeOp = builder.consShape (loc, shape);
3388+ return builder.create <fir::ArrayLoadOp>(
3389+ loc, seqTy, castTo, shapeOp, /* slice=*/ mlir::Value{}, llvm::None);
3390+ };
3391+ // Custom lowering of the element store to deal with the extra indirection
3392+ // to the lazy allocated buffer.
3393+ ccStoreToDest = [=](IterSpace iters) {
3394+ auto load = builder.create <fir::LoadOp>(loc, var);
3395+ auto eleTy = fir::unwrapRefType (load.getType ());
3396+ auto unknown = fir::SequenceType::getUnknownExtent ();
3397+ fir::SequenceType::Shape extents (iters.iterVec ().size (), unknown);
3398+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3399+ auto toTy = fir::HeapType::get (seqTy);
3400+ auto castTo = builder.createConvert (loc, toTy, load);
3401+ auto shape = builder.consShape (loc, genIterationShape ());
3402+ auto indices = fir::factory::originateIndices (
3403+ loc, builder, castTo.getType (), shape, iters.iterVec ());
3404+ auto eleAddr = builder.create <fir::ArrayCoorOp>(
3405+ loc, builder.getRefType (eleTy), castTo, shape,
3406+ /* slice=*/ mlir::Value{}, indices, destination.typeparams ());
3407+ auto eleVal = builder.createConvert (loc, eleTy, iters.getElement ());
3408+ builder.create <fir::StoreOp>(loc, eleVal, eleAddr);
3409+ return iters.innerArgument ();
3410+ };
3411+ auto loopRes = lowerArrayExpression (expr);
3412+ auto load = builder.create <fir::LoadOp>(loc, var);
3413+ auto eleTy = fir::unwrapRefType (load.getType ());
3414+ auto unknown = fir::SequenceType::getUnknownExtent ();
3415+ fir::SequenceType::Shape extents (genIterationShape ().size (), unknown);
3416+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3417+ auto toTy = fir::HeapType::get (seqTy);
3418+ auto tempRes = builder.createConvert (loc, toTy, load);
3419+ builder.create <fir::ArrayMergeStoreOp>(
3420+ loc, destination, fir::getBase (loopRes), tempRes, destination.slice (),
3421+ destination.typeparams ());
3422+ auto tempTy = fir::dyn_cast_ptrEleTy (tempRes.getType ());
3423+ assert (tempTy && tempTy.isa <fir::SequenceType>() &&
3424+ " must be a reference to an array" );
3425+ auto ety = fir::unwrapSequenceType (tempTy);
3426+ if (auto charTy = ety.dyn_cast <fir::CharacterType>()) {
3427+ if (fir::characterWithDynamicLen (charTy))
3428+ TODO (loc, " CHARACTER does not have constant LEN" );
3429+ auto len = builder.createIntegerConstant (
3430+ loc, builder.getCharacterLengthType (), charTy.getLen ());
3431+ return fir::CharArrayBoxValue (tempRes, len, destination.getExtents ());
3432+ }
3433+ return fir::ArrayBoxValue (tempRes, destination.getExtents ());
3434+ }
3435+
33433436 void determineShapeOfDest (const fir::ExtendedValue &lhs) {
33443437 destShape = fir::factory::getExtents (builder, getLoc (), lhs);
33453438 }
@@ -3416,9 +3509,9 @@ class ArrayExprLowering {
34163509 auto innerArg = iterSpace.innerArgument ();
34173510 auto exv = f (iterSpace);
34183511 mlir::Value upd;
3419- if (ccDest .hasValue ()) {
3512+ if (ccStoreToDest .hasValue ()) {
34203513 iterSpace.setElement (std::move (exv));
3421- upd = fir::getBase (ccDest .getValue ()(iterSpace));
3514+ upd = fir::getBase (ccStoreToDest .getValue ()(iterSpace));
34223515 } else {
34233516 auto resTy = adjustedArrayElementType (innerArg.getType ());
34243517 auto element = adjustedArrayElement (loc, builder, fir::getBase (exv),
@@ -3509,6 +3602,14 @@ class ArrayExprLowering {
35093602 // Mask expressions are array expressions too.
35103603 for (const auto *e : implicitSpace->getExprs ())
35113604 if (e && !implicitSpace->isLowered (e)) {
3605+ if (auto var = implicitSpace->lookupMaskVariable (e)) {
3606+ // Allocate the mask buffer lazily.
3607+ auto tmp = Fortran::lower::createLazyArrayTempValue (
3608+ converter, *e, var, symMap, stmtCtx);
3609+ auto shape = builder.createShape (loc, tmp);
3610+ implicitSpace->bind (e, fir::getBase (tmp), shape);
3611+ continue ;
3612+ }
35123613 auto optShape =
35133614 Fortran::evaluate::GetShape (converter.getFoldingContext (), *e);
35143615 auto tmp = Fortran::lower::createSomeArrayTempValue (
@@ -3557,6 +3658,10 @@ class ArrayExprLowering {
35573658 const auto loopDepth = loopUppers.size ();
35583659 llvm::SmallVector<mlir::Value> ivars;
35593660 if (loopDepth > 0 ) {
3661+ // Generate the lazy mask allocation, if one was given.
3662+ if (ccPrelude.hasValue ())
3663+ ccPrelude.getValue ()(shape);
3664+
35603665 auto *startBlock = builder.getBlock ();
35613666 for (auto i : llvm::enumerate (llvm::reverse (loopUppers))) {
35623667 if (i.index () > 0 ) {
@@ -3600,7 +3705,7 @@ class ArrayExprLowering {
36003705 // explicit masks, which are interleaved, these mask expression appear in
36013706 // the innermost loop.
36023707 if (implicitSpaceHasMasks ()) {
3603- auto prependAsNeeded = [&](auto &&indices) {
3708+ auto appendAsNeeded = [&](auto &&indices) {
36043709 llvm::SmallVector<mlir::Value> result;
36053710 result.append (indices.begin (), indices.end ());
36063711 return result;
@@ -3614,7 +3719,7 @@ class ArrayExprLowering {
36143719 auto eleRefTy = builder.getRefType (eleTy);
36153720 auto i1Ty = builder.getI1Type ();
36163721 // Adjust indices for any shift of the origin of the array.
3617- auto indexes = prependAsNeeded (fir::factory::originateIndices (
3722+ auto indexes = appendAsNeeded (fir::factory::originateIndices (
36183723 loc, builder, tmp.getType (), shape, iters.iterVec ()));
36193724 auto addr = builder.create <fir::ArrayCoorOp>(
36203725 loc, eleRefTy, tmp, shape, /* slice=*/ mlir::Value{}, indexes,
@@ -3664,6 +3769,8 @@ class ArrayExprLowering {
36643769 fir::ArrayLoadOp
36653770 createAndLoadSomeArrayTemp (mlir::Type type,
36663771 llvm::ArrayRef<mlir::Value> shape) {
3772+ if (ccLoadDest.hasValue ())
3773+ return ccLoadDest.getValue ()(shape);
36673774 auto seqTy = type.dyn_cast <fir::SequenceType>();
36683775 assert (seqTy && " must be an array" );
36693776 auto loc = getLoc ();
@@ -4613,7 +4720,7 @@ class ArrayExprLowering {
46134720 auto loc = getLoc ();
46144721 auto memref = fir::getBase (extMemref);
46154722 auto arrTy = fir::dyn_cast_ptrOrBoxEleTy (memref.getType ());
4616- assert (arrTy.isa <fir::SequenceType>());
4723+ assert (arrTy.isa <fir::SequenceType>() && " memory ref must be an array " );
46174724 auto shape = builder.createShape (loc, extMemref);
46184725 mlir::Value slice;
46194726 if (inSlice) {
@@ -4898,7 +5005,7 @@ class ArrayExprLowering {
48985005 if (isArray (x)) {
48995006 auto e = toEvExpr (x);
49005007 auto sh = Fortran::evaluate::GetShape (converter.getFoldingContext (), e);
4901- return {lowerSomeNewArrayExpression (converter, symMap, stmtCtx, sh, e),
5008+ return {lowerNewArrayExpression (converter, symMap, stmtCtx, sh, e),
49025009 /* needCopy=*/ true };
49035010 }
49045011 return {asScalar (x), /* needCopy=*/ true };
@@ -5429,7 +5536,10 @@ class ArrayExprLowering {
54295536 Fortran::lower::StatementContext &stmtCtx;
54305537 Fortran::lower::SymMap &symMap;
54315538 // / The continuation to generate code to update the destination.
5432- llvm::Optional<CC> ccDest;
5539+ llvm::Optional<CC> ccStoreToDest;
5540+ llvm::Optional<std::function<void (llvm::ArrayRef<mlir::Value>)>> ccPrelude;
5541+ llvm::Optional<std::function<fir::ArrayLoadOp(llvm::ArrayRef<mlir::Value>)>>
5542+ ccLoadDest;
54335543 // / The destination is the loaded array into which the results will be
54345544 // / merged.
54355545 fir::ArrayLoadOp destination;
@@ -5539,8 +5649,18 @@ fir::ExtendedValue Fortran::lower::createSomeArrayTempValue(
55395649 const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
55405650 Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx) {
55415651 LLVM_DEBUG (expr.AsFortran (llvm::dbgs () << " array value: " ) << ' \n ' );
5542- return ArrayExprLowering::lowerSomeNewArrayExpression (converter, symMap,
5543- stmtCtx, shape, expr);
5652+ return ArrayExprLowering::lowerNewArrayExpression (converter, symMap, stmtCtx,
5653+ shape, expr);
5654+ }
5655+
5656+ fir::ExtendedValue Fortran::lower::createLazyArrayTempValue (
5657+ Fortran::lower::AbstractConverter &converter,
5658+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
5659+ mlir::Value var, Fortran::lower::SymMap &symMap,
5660+ Fortran::lower::StatementContext &stmtCtx) {
5661+ LLVM_DEBUG (expr.AsFortran (llvm::dbgs () << " array value: " ) << ' \n ' );
5662+ return ArrayExprLowering::lowerLazyArrayExpression (converter, symMap, stmtCtx,
5663+ expr, var);
55445664}
55455665
55465666fir::ExtendedValue Fortran::lower::createSomeArrayBox (
@@ -5637,6 +5757,9 @@ void Fortran::lower::createArrayMergeStores(
56375757 builder.create <fir::ArrayMergeStoreOp>(
56385758 loc, load, i.value (), load.memref (), load.slice (), load.typeparams ());
56395759 }
5760+ // Cleanup any residual mask buffers.
5761+ esp.outermostContext ().finalize ();
5762+ esp.outermostContext ().reset ();
56405763 }
56415764 esp.outerLoopStack .pop_back ();
56425765 esp.innerArgsStack .pop_back ();
0 commit comments