@@ -114,3 +114,44 @@ int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
114114 mlir::emitError (loc, " unsupported type" );
115115 return 0 ;
116116}
117+
118+ mlir::Value cuf::computeElementCount (mlir::PatternRewriter &rewriter,
119+ mlir::Location loc,
120+ mlir::Value shapeOperand,
121+ mlir::Type seqType,
122+ mlir::Type targetType) {
123+ if (shapeOperand) {
124+ // Dynamic extent - extract from shape operand
125+ llvm::SmallVector<mlir::Value> extents;
126+ if (auto shapeOp =
127+ mlir::dyn_cast<fir::ShapeOp>(shapeOperand.getDefiningOp ())) {
128+ extents = shapeOp.getExtents ();
129+ } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
130+ shapeOperand.getDefiningOp ())) {
131+ for (auto i : llvm::enumerate (shapeShiftOp.getPairs ()))
132+ if (i.index () & 1 )
133+ extents.push_back (i.value ());
134+ }
135+
136+ if (extents.empty ())
137+ return mlir::Value ();
138+
139+ // Compute total element count by multiplying all dimensions
140+ mlir::Value count =
141+ fir::ConvertOp::create (rewriter, loc, targetType, extents[0 ]);
142+ for (unsigned i = 1 ; i < extents.size (); ++i) {
143+ auto operand =
144+ fir::ConvertOp::create (rewriter, loc, targetType, extents[i]);
145+ count = mlir::arith::MulIOp::create (rewriter, loc, count, operand);
146+ }
147+ return count;
148+ } else {
149+ // Static extent - use constant array size
150+ if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(seqType)) {
151+ mlir::IntegerAttr attr =
152+ rewriter.getIntegerAttr (targetType, seqTy.getConstantArraySize ());
153+ return mlir::arith::ConstantOp::create (rewriter, loc, targetType, attr);
154+ }
155+ }
156+ return mlir::Value ();
157+ }
0 commit comments