diff --git a/MFE_time_size.ipynb b/MFE_time_size.ipynb
new file mode 100644
index 0000000000..91982d13da
--- /dev/null
+++ b/MFE_time_size.ipynb
@@ -0,0 +1,637 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "f11a9f6b",
+ "metadata": {},
+ "source": [
+ "# Multi-Stage Minimum Failure Example (MFE) Time Stepping Simulation\n",
+ "\n",
+ "This notebook implements a multi-stage Runge-Kutta-like time integration scheme for solving a PDE system using Devito.\n",
+ "\n",
+ "## Overview\n",
+ "\n",
+ "To apply multi-stage methods to a PDE system, we first have to formulated as a first-order in time system:\n",
+ "\n",
+ "\n",
+ "$\\frac{d}{dt}\\boldsymbol{U}(\\boldsymbol{x},t)=\\boldsymbol{HU}(\\boldsymbol{x},t)+\\boldsymbol{f}(\\boldsymbol{x},t),\\hspace{3cm}(1)$\n",
+ "\n",
+ "where $\\boldsymbol{x}$ is a vector with the spatial variables (x in 1D, (x,y) in 2D, (x,y,z in 3D)), $\\boldsymbol{U}$ is a vector of real-valued functions, $\\boldsymbol{H}$ is a spatial operator that contains spatial derivatives and constant coefficients, and $\\boldsymbol{f}$ is a known vector of real-valued functions.\n",
+ "\n",
+ "Expanding $\\boldsymbol{f}(x,y,t)$ in its Taylor 's series, we can approximate the solution of system (1) by the matrix exponential (see [[1](#ref-higham2010)])\n",
+ "\n",
+ "$\\hat{\\boldsymbol{U}}(\\boldsymbol{x},t)\\approx[\\boldsymbol{I_p}\\;0]e^{t\\hat{\\boldsymbol{H}}}\\begin{bmatrix}\\boldsymbol{U}(\\boldsymbol{x},t_0)\\\\ \\boldsymbol{e_p}\\end{bmatrix},$\n",
+ "\n",
+ "where $\\boldsymbol{e_p}\\in\\mathbb{R}^p$ is the eigenvector with zero in all its entries exept the last one, that equals 1, $\\hat{\\boldsymbol{U}}$ is the vector function $\\boldsymbol{U}$ with $p$ extra zeros at the end, $\\boldsymbol{I_p}$ is the identity matrix of size $p$, and $\\hat{\\boldsymbol{H}}$ is the matrix\n",
+ "\n",
+ "$\\hat{\\boldsymbol{H}}=\\begin{bmatrix}\\boldsymbol{H} & \\boldsymbol{W}\\\\ 0 & \\boldsymbol{J_p}\\end{bmatrix}, \\hspace{5.265cm}(2)$\n",
+ "\n",
+ "where $\\boldsymbol{J_p}$ is the zero matrix with an upper diagonal of ones\n",
+ "\n",
+ "$\\boldsymbol{J_p}=\\begin{bmatrix}0 & 1 & 0 & 0 & \\dots & 0\\\\ 0 & 0 & 1 & 0 & \\dots & 0\\\\ &&\\ddots&&&\\vdots \n",
+ "\\\\ 0 & 0 & 0 & 0 & \\dots & 1 \\\\0 & 0 & 0 & 0 & \\dots & 0\\end{bmatrix},$\n",
+ "\n",
+ "and $\\boldsymbol{W}$ contains the information of the vector function $\\boldsymbol{f}(\\boldsymbol{x},t)$ derivatives\n",
+ "\n",
+ "$\\boldsymbol{W}=\\begin{bmatrix}\\frac{\\partial^{p-1}}{\\partial t^{p-1}}\\boldsymbol{f}(\\boldsymbol{x},t_0)\\bigg\\vert \\frac{\\partial^{p-2}}{\\partial t^{p-2}}\\boldsymbol{f}(\\boldsymbol{x},t_0)\\bigg\\vert \\dots \\bigg\\vert \\frac{\\partial}{\\partial t}\\boldsymbol{f}(\\boldsymbol{x},t_0)\\bigg\\vert \\boldsymbol{f}(\\boldsymbol{x},t_0)\\end{bmatrix}.$\n",
+ "\n",
+ "Then, we approximate the solution operator in (2) with the m-stage Runge-Kutta (RK) method from [[2](#ref-gottlieb2003)]\n",
+ "\n",
+ "\\begin{align*}\n",
+ " \\boldsymbol{k}_0&=\\boldsymbol{u}_n\\\\\n",
+ " \\boldsymbol{k}_i&=\\left(\\boldsymbol{I}_{n\\times n}+\\Delta t \\hat{\\boldsymbol{H}}\\right)\\boldsymbol{k}_{i-1},\\quad i=1\\dots m-1\\\\\n",
+ " \\boldsymbol{k}_m&=\\sum\\limits_{i=0}^{m-2}\\alpha_i\\boldsymbol{k}_i+\\alpha_{m-1}\\left(\\boldsymbol{I}_{n\\times n}+\\Delta t \\hat{\\boldsymbol{H}}\\right)\\boldsymbol{k}_{m-1}\\\\\n",
+ " \\boldsymbol{u}_{n+1}&=\\boldsymbol{k}_m, \n",
+ "\\end{align*}\n",
+ "\n",
+ "where $\\alpha_i$ are the coefficients of the Runge-Kutta and have a straightforward computation.\n",
+ "\n",
+ "So, for each time step we have to construct the matrix $\\hat{\\boldsymbol{H}}$, where we only the submatrix $\\boldsymbol{W}$ change, and apply the RK method to the vector $\\boldsymbol{u}_n=[\\boldsymbol{U}(\\boldsymbol{x},t_n)\\;\\; \\boldsymbol{e_p}]^T$.\n",
+ "\n",
+ "So, an outline of the implementation is:\n",
+ "\n",
+ "- **Environmental variables**: Set up the grid, and spatial and time variables\n",
+ "- **PDE System definition**: Define $\\boldsymbol{f}(\\boldsymbol{x},t)$ and the system of equations\n",
+ "- **Compute the derivatives of the source term**: symbolic computing of $\\boldsymbol{f}(\\boldsymbol{x},t)$ derivatives\n",
+ "- **Construct the operator $\\hat{\\boldsymbol{H}}$**: application the application of operator $\\hat{\\boldsymbol{H}}$\n",
+ "- **Implementation of the RK method**: define the required equations of the RK method to pass to Devito's operator. For this particular example, we'll use $m=3$ and $\\alpha_i=1,\\;\\forall i\\in\\{0,1,2\\}$.\n",
+ "- **Create and run Devito operator**: Executing the operator constructed with the RK equations"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "274432cf",
+ "metadata": {},
+ "source": [
+ "## 0. Import Required Libraries\n",
+ "\n",
+ "First, we import all necessary libraries including NumPy for numerical operations, SymPy for symbolic mathematics, and Devito components for finite difference operations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "930afc0c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import sympy as sym\n",
+ "from devito import (Grid, Function, TimeFunction,\n",
+ " Derivative, Operator, Eq)\n",
+ "from devito import configuration\n",
+ "from devito.symbolics import uxreplace"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3a3b1e0e",
+ "metadata": {},
+ "source": [
+ "## 1. Environmental variables\n",
+ "\n",
+ "Configure the simulation environment, including logging level, domain parameters, and computational grid setup."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "0d2a32c5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Grid created: (201, 201) points over domain (1, 1)\n",
+ "Spatial dimensions: x=x, y=y\n",
+ "Time dimension: t=time, dt=dt\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Configure Devito logging\n",
+ "configuration['log-level'] = 'DEBUG'\n",
+ "\n",
+ "# Simulation parameters\n",
+ "extent = (1, 1) # Physical domain size\n",
+ "shape = (201, 201) # Grid resolution\n",
+ "origin = (0, 0) # Domain origin\n",
+ "\n",
+ "# Create computational grid\n",
+ "grid = Grid(origin=origin, extent=extent, shape=shape, dtype=np.float64)\n",
+ "x, y = grid.dimensions\n",
+ "t, dt = grid.time_dim, grid.stepping_dim.spacing\n",
+ "\n",
+ "print(f\"Grid created: {shape} points over domain {extent}\")\n",
+ "print(f\"Spatial dimensions: x={x}, y={y}\")\n",
+ "print(f\"Time dimension: t={t}, dt={dt}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d1d7dfd2",
+ "metadata": {},
+ "source": [
+ "## 2. PDE System definition\n",
+ "\n",
+ "Create the TimeFunction objects for the wavefield variables (displacement u and velocity v) and define the source terms with both spatial and temporal characteristics."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "5c42f929",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Allocating host memory for src_spat(205, 205) [328 KB]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Wavefield functions created:\n",
+ " u_multi_stage: u_multi_stage(t, x, y)\n",
+ " v_multi_stage: v_multi_stage(t, x, y)\n",
+ "\n",
+ "Source spatial function: src_spat(x, y)\n",
+ "Source temporal function: exp(-100*(time - 0.01)**2)\n",
+ "\n",
+ "PDE system matrix H:\n",
+ " du_multi_stage/dt = v_multi_stage(t, x, y)\n",
+ " dv_multi_stage/dt = Derivative(u_multi_stage(t, x, y), (x, 2)) + Derivative(u_multi_stage(t, x, y), (y, 2))\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Define wavefield unknowns: u (displacement) and v (velocity)\n",
+ "fun_labels = ['u_multi_stage', 'v_multi_stage']\n",
+ "U_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, time_order=1, dtype=np.float64) for name in fun_labels]\n",
+ "\n",
+ "print(\"Wavefield functions created:\")\n",
+ "for i, u in enumerate(U_multi_stage):\n",
+ " print(f\" {fun_labels[i]}: {u}\")\n",
+ "\n",
+ "# Source definition\n",
+ "src_spatial = Function(name=\"src_spat\", grid=grid, space_order=2, dtype=np.float64)\n",
+ "src_spatial.data[100, 100] = 1 # Point source at grid center\n",
+ "src_temporal = sym.exp(- 100 * (t - 0.01) ** 2) # Gaussian pulse\n",
+ "\n",
+ "print(f\"\\nSource spatial function: {src_spatial}\")\n",
+ "print(f\"Source temporal function: {src_temporal}\")\n",
+ "\n",
+ "# PDE right-hand side: du/dt = v, dv/dt = ∇²u\n",
+ "system_eqs_rhs = [U_multi_stage[1], # du/dt = v\n",
+ " Derivative(U_multi_stage[0], (x, 2), fd_order=2) +\n",
+ " Derivative(U_multi_stage[0], (y, 2), fd_order=2)] # dv/dt = ∇²u\n",
+ "\n",
+ "# Source coupling: [spatial, temporal, associated variable]\n",
+ "src = [[src_spatial, src_temporal, U_multi_stage[0]],\n",
+ " [src_spatial, src_temporal * 10, U_multi_stage[0]],\n",
+ " [src_spatial, src_temporal, U_multi_stage[1]]]\n",
+ "\n",
+ "print(f\"\\nPDE system matrix H:\")\n",
+ "for i, rhs in enumerate(system_eqs_rhs):\n",
+ " print(f\" d{fun_labels[i]}/dt = {rhs}\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "38f6faa0",
+ "metadata": {},
+ "source": [
+ "## 3. Compute the derivatives of the source term\n",
+ "\n",
+ "Implement the core helper function that compute time derivatives of source and calculate the source derivatives up to the specified degree (deg=3)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "046482b7",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " - source_derivatives(): Computes time derivatives of source wavelet\n",
+ "Source index mapping:\n",
+ " u_multi_stage(t, x, y) -> 0\n",
+ " v_multi_stage(t, x, y) -> 1\n",
+ "\n",
+ "Source indices: [0, 0, 1]\n",
+ "\n",
+ "Source derivatives computed up to degree 3\n",
+ "Number of derivative levels: 3\n",
+ "Sample derivative expressions:\n",
+ " Level 2: [(2.0 - 200*time)**2*exp(-100*(time - 0.01)**2) - 200*exp(-100*(time - 0.01)**2), 10*(2.0 - 200*time)**2*exp(-100*(time - 0.01)**2) - 2000*exp(-100*(time - 0.01)**2), (2.0 - 200*time)**2*exp(-100*(time - 0.01)**2) - 200*exp(-100*(time - 0.01)**2)]\n",
+ " Level 1: [(2.0 - 200*time)*exp(-100*(time - 0.01)**2), 10*(2.0 - 200*time)*exp(-100*(time - 0.01)**2), (2.0 - 200*time)*exp(-100*(time - 0.01)**2)]\n",
+ " Level 0: [exp(-100*(time - 0.01)**2), 10*exp(-100*(time - 0.01)**2), exp(-100*(time - 0.01)**2)]\n"
+ ]
+ }
+ ],
+ "source": [
+ "def source_derivatives(deg, src, src_index, t):\n",
+ " \"\"\"\n",
+ " Compute time derivatives of the source up to given degree.\n",
+ " \n",
+ " Parameters:\n",
+ " -----------\n",
+ " deg : int\n",
+ " Degree of derivatives to compute\n",
+ " src : list\n",
+ " List of source terms\n",
+ " src_index : list\n",
+ " Indices for source terms\n",
+ " t : symbol\n",
+ " Time symbol\n",
+ " \n",
+ " Returns:\n",
+ " --------\n",
+ " f_deriv : list\n",
+ " List of derivative expressions\n",
+ " \"\"\"\n",
+ " f_deriv = [[src[i][1] for i in range(len(src))]]\n",
+ " \n",
+ " # Compute derivatives up to order p\n",
+ " for _ in range(deg - 1):\n",
+ " f_deriv.append([f_deriv[-1][i].diff(t) for i in range(len(src_index))])\n",
+ " \n",
+ " f_deriv.reverse()\n",
+ " return f_deriv\n",
+ "\n",
+ "print(\" - source_derivatives(): Computes time derivatives of source wavelet\")\n",
+ "\n",
+ "# Create mapping from wavefield variables to indices\n",
+ "src_index_map = {val: i for i, val in enumerate(U_multi_stage)}\n",
+ "print(\"Source index mapping:\")\n",
+ "for var, idx in src_index_map.items():\n",
+ " print(f\" {var} -> {idx}\")\n",
+ "\n",
+ "# Extract source indices based on associated variables\n",
+ "src_index = [src_index_map[val] for val in [src[i][2] for i in range(len(src))]]\n",
+ "print(f\"\\nSource indices: {src_index}\")\n",
+ "\n",
+ "# Degree of derivatives to compute\n",
+ "deg = 3\n",
+ "\n",
+ "# Compute source derivatives\n",
+ "src_deriv = source_derivatives(deg, src, src_index, t)\n",
+ "print(f\"\\nSource derivatives computed up to degree {deg}\")\n",
+ "print(f\"Number of derivative levels: {len(src_deriv)}\")\n",
+ "print(\"Sample derivative expressions:\")\n",
+ "for i, deriv in enumerate(src_deriv):\n",
+ " print(f\" Level {deg-1-i}: {deriv}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "60c485ae",
+ "metadata": {},
+ "source": [
+ "## 4.Construct the operator $\\hat{\\boldsymbol{H}}$\n",
+ "\n",
+ " Application of the spatial operator to the vector formed by [u e_p]^T."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "ccd0fe2a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " - source_inclusion(): Apply the spatial operator of the PDE system at RK stages\n"
+ ]
+ }
+ ],
+ "source": [
+ "# and handle source term inclusion in the PDE system at each Runge-Kutta stage\n",
+ "def source_inclusion(rhs, u, k, src_index, src_deriv, e_p, t, dt, n_eq):\n",
+ " \"\"\"\n",
+ " Add source terms to the PDE system at each Runge-Kutta stage.\n",
+ " \n",
+ " Parameters:\n",
+ " -----------\n",
+ " rhs : list\n",
+ " Right-hand side of PDE system without sources\n",
+ " u : list \n",
+ " Wavefield variables\n",
+ " k : list\n",
+ " Runge-Kutta stage variables\n",
+ " src_index : list\n",
+ " Source indices\n",
+ " src_deriv : list\n",
+ " Source derivatives\n",
+ " e_p : list\n",
+ " Expansion coefficients of the source term Taylor's series\n",
+ " t : symbol\n",
+ " Time symbol\n",
+ " dt : symbol\n",
+ " Time step symbol\n",
+ " n_eq : int\n",
+ " Number of equations\n",
+ " \n",
+ " Returns:\n",
+ " --------\n",
+ " src_lhs : list\n",
+ " Operator application to the vector [u, e_p]^T including source terms\n",
+ " e_p : list\n",
+ " Updated expansion coefficients of the source term Taylor's series\n",
+ " \"\"\"\n",
+ " # Replace wavefield variables with stage variables\n",
+ " src_lhs = [uxreplace(rhs[i], {u[m]: k[m] for m in range(n_eq)}) for i in range(n_eq)]\n",
+ " \n",
+ " p = len(src_deriv)\n",
+ " \n",
+ " # Add source contributions\n",
+ " for i in range(p):\n",
+ " if e_p[i] != 0:\n",
+ " for j in range(len(src_index)):\n",
+ " src_lhs[src_index[j]] += src[j][0] * src_deriv[i][j].subs({t: t * dt}) * e_p[i]\n",
+ " \n",
+ " # Update expansion coefficients of the source term Taylor's series\n",
+ " e_p = [e_p[i] + dt * e_p[i + 1] for i in range(p - 1)] + [e_p[-1]]\n",
+ " \n",
+ " return src_lhs, e_p\n",
+ "\n",
+ "print(\" - source_inclusion(): Apply the spatial operator of the PDE system at RK stages\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a462c0a9",
+ "metadata": {},
+ "source": [
+ "## 5. Implementation of the RK method\n",
+ "\n",
+ "Construct the multi-stage time integration scheme with initialization, multiple RK stages, and final update. This implements a toy example of a class of High-Order Runge-Kutta (HORK) methods, with proper source term integration."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "0543e6d9",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Runge-Kutta temporary variables:\n",
+ " k: ['k0', 'k1']\n",
+ " k_old: ['k00', 'k01']\n",
+ "\n",
+ "Expansion coefficients initialized of the source term Taylor's series:\n",
+ " e_p = [0, 0, 1.0]\n",
+ " eta = 1\n",
+ "\n",
+ "SSPRK coefficients initialized:\n",
+ " alpha = [1, 1, 1, 1]\n",
+ "Building Runge-Kutta-like scheme...\n",
+ " Stage 0: Initialization\n",
+ " Stage 1: First RK stage\n",
+ " Stage 2: Second RK stage\n",
+ " Stage 3: Final RK stage and update\n",
+ "\n",
+ "Total equations created: 20\n",
+ "Scheme structure:\n",
+ " Stage 0: Eq(k0(x, y), u_multi_stage(t, x, y))\n",
+ " Stage 1: Eq(k1(x, y), v_multi_stage(t, x, y))\n",
+ " Stage 2: Eq(u_multi_stage(t + dt, x, y), u_multi_stage(t, x, y))\n",
+ " Stage 3: Eq(v_multi_stage(t + dt, x, y), v_multi_stage(t, x, y))\n",
+ " Stage 4: Eq(k00(x, y), k0(x, y))\n",
+ " Stage 5: Eq(k01(x, y), k1(x, y))\n",
+ " Stage 6: Eq(k0(x, y), dt*(k01(x, y) + 11.0*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2)) + k00(x, y))\n",
+ " Stage 7: Eq(k1(x, y), dt*(src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + Derivative(k00(x, y), (x, 2)) + Derivative(k00(x, y), (y, 2))) + k01(x, y))\n",
+ " Stage 8: Eq(u_multi_stage(t + dt, x, y), k0(x, y) + u_multi_stage(t + dt, x, y))\n",
+ " Stage 9: Eq(v_multi_stage(t + dt, x, y), k1(x, y) + v_multi_stage(t + dt, x, y))\n",
+ " Stage 10: Eq(k00(x, y), k0(x, y))\n",
+ " Stage 11: Eq(k01(x, y), k1(x, y))\n",
+ " Stage 12: Eq(k0(x, y), dt*(11.0*dt*(-200*time*dt + 2.0)*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + k01(x, y) + 11.0*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2)) + k00(x, y))\n",
+ " Stage 13: Eq(k1(x, y), dt*(1.0*dt*(-200*time*dt + 2.0)*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + Derivative(k00(x, y), (x, 2)) + Derivative(k00(x, y), (y, 2))) + k01(x, y))\n",
+ " Stage 14: Eq(k00(x, y), k0(x, y))\n",
+ " Stage 15: Eq(k01(x, y), k1(x, y))\n",
+ " Stage 16: Eq(k0(x, y), dt*(1.0*dt**2*((-200*time*dt + 2.0)**2*exp(-100*(time*dt - 0.01)**2) - 200*exp(-100*(time*dt - 0.01)**2))*src_spat(x, y) + 1.0*dt**2*(10*(-200*time*dt + 2.0)**2*exp(-100*(time*dt - 0.01)**2) - 2000*exp(-100*(time*dt - 0.01)**2))*src_spat(x, y) + 22.0*dt*(-200*time*dt + 2.0)*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + k01(x, y) + 11.0*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2)) + k00(x, y))\n",
+ " Stage 17: Eq(k1(x, y), dt*(1.0*dt**2*((-200*time*dt + 2.0)**2*exp(-100*(time*dt - 0.01)**2) - 200*exp(-100*(time*dt - 0.01)**2))*src_spat(x, y) + 2.0*dt*(-200*time*dt + 2.0)*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + Derivative(k00(x, y), (x, 2)) + Derivative(k00(x, y), (y, 2))) + k01(x, y))\n",
+ " Stage 18: Eq(u_multi_stage(t + dt, x, y), k0(x, y) + u_multi_stage(t + dt, x, y))\n",
+ " Stage 19: Eq(v_multi_stage(t + dt, x, y), k1(x, y) + v_multi_stage(t + dt, x, y))\n"
+ ]
+ }
+ ],
+ "source": [
+ "n_eq = 2 # Number of PDE unknowns (u, v)\n",
+ "\n",
+ "# Temporary variables for Runge-Kutta stages\n",
+ "k = [Function(name=f'k{i}', grid=grid, space_order=2, time_order=1, dtype=U_multi_stage[i].dtype) for i in range(n_eq)]\n",
+ "# Previous stage variables needed for temporary storage\n",
+ "k_old = [Function(name=f'k0{i}', grid=grid, space_order=2, time_order=1, dtype=U_multi_stage[i].dtype) for i in range(n_eq)]\n",
+ "\n",
+ "print(f\"\\nRunge-Kutta temporary variables:\")\n",
+ "print(f\" k: {[ki.name for ki in k]}\")\n",
+ "print(f\" k_old: {[ki.name for ki in k_old]}\")\n",
+ "\n",
+ "# Initialize expansion coefficients of the source term Taylor's series\n",
+ "e_p = [0] * deg\n",
+ "eta = 1\n",
+ "e_p[-1] = 1 / eta\n",
+ "print(f\"\\nExpansion coefficients initialized of the source term Taylor's series:\")\n",
+ "print(f\" e_p = {e_p}\")\n",
+ "print(f\" eta = {eta}\")\n",
+ "\n",
+ "# Initialize SSPRK coefficients (toy example)\n",
+ "alpha = [1]*4\n",
+ "print(f\"\\nSSPRK coefficients initialized:\")\n",
+ "print(f\" alpha = {alpha}\")\n",
+ "\n",
+ "\n",
+ "# Initialize list to store all stage equations\n",
+ "stage_eqs = []\n",
+ "\n",
+ "print(\"Building Runge-Kutta-like scheme...\")\n",
+ "\n",
+ "# Stage 0: Initialization\n",
+ "print(\" Stage 0: Initialization\")\n",
+ "stage_eqs.extend([Eq(k[i], U_multi_stage[i]) for i in range(n_eq)])\n",
+ "[stage_eqs.append(Eq(U_multi_stage[i].forward, U_multi_stage[i] * alpha[0])) for i in range(n_eq)]\n",
+ "\n",
+ "# Stage 1\n",
+ "print(\" Stage 1: First RK stage\")\n",
+ "[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]\n",
+ "src_lhs, e_p = source_inclusion(system_eqs_rhs, U_multi_stage, k_old, src_index, src_deriv, e_p, t, dt, n_eq)\n",
+ "[stage_eqs.append(Eq(k[j], k_old[j] + dt * src_lhs[j])) for j in range(n_eq)]\n",
+ "[stage_eqs.append(Eq(U_multi_stage[j].forward, U_multi_stage[j].forward + k[j] * alpha[1])) for j in range(n_eq)]\n",
+ "\n",
+ "# Stage 2\n",
+ "print(\" Stage 2: Second RK stage\")\n",
+ "[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]\n",
+ "src_lhs, e_p = source_inclusion(system_eqs_rhs, U_multi_stage, k_old, src_index, src_deriv, e_p, t, dt, n_eq)\n",
+ "[stage_eqs.append(Eq(k[j], k_old[j] + dt * src_lhs[j])) for j in range(n_eq)]\n",
+ "\n",
+ "# Stage 3 and final update\n",
+ "print(\" Stage 3: Final RK stage and update\")\n",
+ "[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]\n",
+ "src_lhs, _ = source_inclusion(system_eqs_rhs, U_multi_stage, k_old, src_index, src_deriv, e_p, t, dt, n_eq)\n",
+ "[stage_eqs.append(Eq(k[j], k_old[j] + dt * src_lhs[j])) for j in range(n_eq)]\n",
+ "[stage_eqs.append(Eq(U_multi_stage[j].forward, U_multi_stage[j].forward + k[j] * alpha[deg - 1])) for j in range(n_eq)]\n",
+ "\n",
+ "print(f\"\\nTotal equations created: {len(stage_eqs)}\")\n",
+ "print(\"Scheme structure:\")\n",
+ "for i, stage in enumerate(stage_eqs):\n",
+ " print(f\" Stage {i}: {stage}\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ef2048a1",
+ "metadata": {},
+ "source": [
+ "## 6. Create and Run Devito Operator\n",
+ "\n",
+ "Compile all the equations into a Devito Operator and execute the simulation with the specified parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "d64897d9",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Creating Devito Operator...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Operator `Kernel` generated in 0.33 s\n",
+ " * lowering.Clusters: 0.17 s (52.7 %)\n",
+ " * specializing.Clusters: 0.10 s (31.0 %)\n",
+ " * lowering.IET: 0.11 s (34.1 %)\n",
+ "Flops reduction after symbolic optimization: [209 --> 106]\n",
+ "Operator `Kernel` fetched `/tmp/devito-jitcache-uid1000/ead4e1022d510052a8b9b1fb08d861635f2bdbdd.c` in 0.08 s from jit-cache\n",
+ "Allocating host memory for k0(205, 205) [328 KB]\n",
+ "Allocating host memory for k00(205, 205) [328 KB]\n",
+ "Allocating host memory for k01(205, 205) [328 KB]\n",
+ "Allocating host memory for k1(205, 205) [328 KB]\n",
+ "Allocating host memory for u_multi_stage(2, 205, 205) [657 KB]\n",
+ "Allocating host memory for v_multi_stage(2, 205, 205) [657 KB]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Operator successfully created!\n",
+ " Number of equations: 20\n",
+ " Grid spacing substitutions applied: {h_x: np.float64(0.005), h_y: np.float64(0.005)}\n",
+ "\n",
+ "Simulation parameters:\n",
+ " Time step (dt): 0.001\n",
+ " Maximum time: 2000\n",
+ "\n",
+ "Running simulation...\n"
+ ]
+ },
+ {
+ "ename": "InvalidArgument",
+ "evalue": "No value found for parameter time_size",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
+ "\u001b[31mInvalidArgument\u001b[39m Traceback (most recent call last)",
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[11]\u001b[39m\u001b[32m, line 19\u001b[39m\n\u001b[32m 17\u001b[39m \u001b[38;5;66;03m# Execute the simulation\u001b[39;00m\n\u001b[32m 18\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33mRunning simulation...\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m \u001b[43mop\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdt\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdt_value\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtime\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtime_max\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 21\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mSimulation completed successfully!\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 23\u001b[39m \u001b[38;5;66;03m# Display final wavefield shapes and some statistics\u001b[39;00m\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Trabajo/pos-doc/posdoc_fernando/devito/devito/operator/operator.py:878\u001b[39m, in \u001b[36mOperator.__call__\u001b[39m\u001b[34m(self, **kwargs)\u001b[39m\n\u001b[32m 877\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, **kwargs):\n\u001b[32m--> \u001b[39m\u001b[32m878\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Trabajo/pos-doc/posdoc_fernando/devito/devito/operator/operator.py:995\u001b[39m, in \u001b[36mOperator.apply\u001b[39m\u001b[34m(self, **kwargs)\u001b[39m\n\u001b[32m 993\u001b[39m \u001b[38;5;66;03m# Build the arguments list to invoke the kernel function\u001b[39;00m\n\u001b[32m 994\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m._profiler.timer_on(\u001b[33m'\u001b[39m\u001b[33marguments-preprocess\u001b[39m\u001b[33m'\u001b[39m):\n\u001b[32m--> \u001b[39m\u001b[32m995\u001b[39m args = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43marguments\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 996\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m switch_log_level(comm=args.comm):\n\u001b[32m 997\u001b[39m \u001b[38;5;28mself\u001b[39m._emit_args_profiling(\u001b[33m'\u001b[39m\u001b[33marguments-preprocess\u001b[39m\u001b[33m'\u001b[39m)\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Trabajo/pos-doc/posdoc_fernando/devito/devito/operator/operator.py:773\u001b[39m, in \u001b[36mOperator.arguments\u001b[39m\u001b[34m(self, **kwargs)\u001b[39m\n\u001b[32m 771\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.parameters:\n\u001b[32m 772\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m args.get(p.name) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m773\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m InvalidArgument(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mNo value found for parameter \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mp.name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m args\n",
+ "\u001b[31mInvalidArgument\u001b[39m: No value found for parameter time_size"
+ ]
+ }
+ ],
+ "source": [
+ "# Create the Devito Operator\n",
+ "print(\"Creating Devito Operator...\")\n",
+ "op = Operator(stage_eqs, subs=grid.spacing_map)\n",
+ "\n",
+ "print(\"Operator successfully created!\")\n",
+ "print(f\" Number of equations: {len(stage_eqs)}\")\n",
+ "print(f\" Grid spacing substitutions applied: {grid.spacing_map}\")\n",
+ "\n",
+ "# Define simulation parameters\n",
+ "dt_value = 0.001 # Time step size\n",
+ "time_max = 2000 # Maximum simulation time\n",
+ "\n",
+ "print(f\"\\nSimulation parameters:\")\n",
+ "print(f\" Time step (dt): {dt_value}\")\n",
+ "print(f\" Maximum time: {time_max}\")\n",
+ "\n",
+ "# Execute the simulation\n",
+ "print(\"\\nRunning simulation...\")\n",
+ "op(dt=dt_value, time=time_max)\n",
+ "\n",
+ "print(\"Simulation completed successfully!\")\n",
+ "\n",
+ "# Display final wavefield shapes and some statistics\n",
+ "print(f\"\\nFinal wavefield statistics:\")\n",
+ "for i, u in enumerate(U_multi_stage):\n",
+ " print(f\" {fun_labels[i]}:\")\n",
+ " print(f\" Shape: {u.data.shape}\")\n",
+ " print(f\" Max value: {np.max(u.data):.6f}\")\n",
+ " print(f\" Min value: {np.min(u.data):.6f}\")\n",
+ " print(f\" Mean value: {np.mean(u.data):.6f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4e028d49",
+ "metadata": {},
+ "source": [
+ "## 7. References\n",
+ "\n",
+ "\n",
+ "\n",
+ "[1] Al-Mohy AH, Higham NJ (2010) A new scaling and squaring algorithm for\n",
+ "the matrix exponential. SIAM Journal on Matrix Analysis and Applications\n",
+ "31(3):970–989\n",
+ "\n",
+ "\n",
+ "[2] Gottlieb S, Gottlieb LAJ (2003) Strong stability preserving properties of runge–kutta time discretization methods for linear constant coefficient operators. Journal of Scientific Computing 18(1):83–109"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "devito_b",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/devito/ir/equations/algorithms.py b/devito/ir/equations/algorithms.py
index ce844887aa..66b44e34f9 100644
--- a/devito/ir/equations/algorithms.py
+++ b/devito/ir/equations/algorithms.py
@@ -6,14 +6,15 @@
from devito.tools import (Ordering, as_tuple, flatten, filter_sorted, filter_ordered,
frozendict)
from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension,
- ConditionalDimension)
+ ConditionalDimension, MultiStage)
from devito.types.array import Array
from devito.types.basic import AbstractFunction
from devito.types.dimension import MultiSubDimension, Thickness
from devito.data.allocators import DataReference
from devito.logger import warning
-__all__ = ['dimension_sort', 'lower_exprs', 'concretize_subdims']
+
+__all__ = ['dimension_sort', 'lower_multistage', 'lower_exprs', 'concretize_subdims']
def dimension_sort(expr):
@@ -95,6 +96,39 @@ def handle_indexed(indexed):
return ordering
+def lower_multistage(expressions, **kwargs):
+ """
+ Separating the multi-stage time-integrator scheme in stages:
+ * If the object is MultiStage, it creates the stages of the method.
+ """
+ return _lower_multistage(expressions, **kwargs)
+
+
+@singledispatch
+def _lower_multistage(expr, **kwargs):
+ """
+ Default handler for expressions that are not MultiStage.
+ Simply return them in a list.
+ """
+ return [expr]
+
+
+@_lower_multistage.register(MultiStage)
+def _(expr, **kwargs):
+ """
+ Specialized handler for MultiStage expressions.
+ """
+ return expr._evaluate(**kwargs)
+
+
+@_lower_multistage.register(Iterable)
+def _(exprs, **kwargs):
+ """
+ Handle iterables of expressions.
+ """
+ return sum([_lower_multistage(expr, **kwargs) for expr in exprs], [])
+
+
def lower_exprs(expressions, subs=None, **kwargs):
"""
Lowering an expression consists of the following passes:
diff --git a/devito/operations/solve.py b/devito/operations/solve.py
index 0203dbe26d..498ad376f9 100644
--- a/devito/operations/solve.py
+++ b/devito/operations/solve.py
@@ -7,6 +7,8 @@
from devito.finite_differences.derivative import Derivative
from devito.tools import as_tuple
+from devito.types.multistage import resolve_method
+
__all__ = ['solve', 'linsolve']
@@ -56,9 +58,12 @@ def solve(eq, target, **kwargs):
# We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions
if len(sols) > 1:
- return target.new_from_mat(sols)
+ sols_temp = target.new_from_mat(sols)
else:
- return sols[0]
+ sols_temp = sols[0]
+
+ method = kwargs.get("method", None)
+ return sols_temp if method is None else resolve_method(method)(target, sols_temp)
def linsolve(expr, target, **kwargs):
diff --git a/devito/operator/operator.py b/devito/operator/operator.py
index 0d473fe6a2..7f5b769a77 100644
--- a/devito/operator/operator.py
+++ b/devito/operator/operator.py
@@ -17,7 +17,7 @@
InvalidOperator)
from devito.logger import (debug, info, perf, warning, is_log_enabled_for,
switch_log_level)
-from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
+from devito.ir.equations import LoweredEq, lower_multistage, lower_exprs, concretize_subdims
from devito.ir.clusters import ClusterGroup, clusterize
from devito.ir.iet import (Callable, CInterface, EntryFunction, DeviceFunction,
FindSymbols, MetaCall, derive_parameters, iet_build)
@@ -40,7 +40,6 @@
disk_layer)
from devito.types.dimension import Thickness
-
__all__ = ['Operator']
@@ -337,6 +336,8 @@ def _lower_exprs(cls, expressions, **kwargs):
* Apply substitution rules;
* Shift indices for domain alignment.
"""
+ expressions = lower_multistage(expressions, **kwargs)
+
expand = kwargs['options'].get('expand', True)
# Specialization is performed on unevaluated expressions
diff --git a/devito/types/__init__.py b/devito/types/__init__.py
index 6ec8bdfd16..a8c1d3b224 100644
--- a/devito/types/__init__.py
+++ b/devito/types/__init__.py
@@ -22,3 +22,5 @@
from .relational import * # noqa
from .sparse import * # noqa
from .tensor import * # noqa
+
+from .multistage import * # noqa
diff --git a/devito/types/multistage.py b/devito/types/multistage.py
new file mode 100644
index 0000000000..55ab8f9826
--- /dev/null
+++ b/devito/types/multistage.py
@@ -0,0 +1,540 @@
+from devito.types.equation import Eq
+from devito.types.dense import Function, TimeFunction
+from devito.symbolics import uxreplace
+import numpy as np
+from devito.types.array import Array
+from types import MappingProxyType
+
+method_registry = {}
+
+
+def register_method(cls=None, *, aliases=None):
+ """
+ Register a time integration method class.
+
+ Parameters
+ ----------
+ cls : class, optional
+ The method class to register.
+ aliases : list of str, optional
+ Additional aliases for the method.
+ """
+ def decorator(cls):
+ # Register the class name
+ method_registry[cls.__name__] = cls
+
+ # Register any aliases
+ if aliases:
+ for alias in aliases:
+ method_registry[alias] = cls
+
+ return cls
+
+ if cls is None:
+ # Called as @register_method(aliases=['alias1'])
+ return decorator
+ else:
+ # Called as @register_method
+ return decorator(cls)
+
+
+def resolve_method(method):
+ """
+ Resolve a time integration method by name.
+
+ Parameters
+ ----------
+ method : str
+ Name or alias of the time integration method.
+
+ Returns
+ -------
+ class
+ The method class.
+
+ Raises
+ ------
+ ValueError
+ If the method is not found in the registry.
+ """
+ try:
+ return method_registry[method]
+ except KeyError:
+ available = sorted(method_registry.keys())
+ raise ValueError(
+ f"The time integrator '{method}' is not implemented. "
+ f"Available methods: {available}"
+ )
+
+
+def multistage_method(lhs, rhs, method, degree=None, source=None):
+ method_cls = resolve_method(method)
+ return method_cls(lhs, rhs, degree=degree, source=source)
+
+
+class MultiStage(Eq):
+ """
+ Abstract base class for multi-stage time integration methods
+ (e.g., Runge-Kutta schemes) in Devito.
+
+ This class represents a symbolic equation of the form `target = rhs`
+ and provides a mechanism to associate it with a time integration
+ scheme. The specific integration behavior must be implemented by
+ subclasses via the `_evaluate` method.
+
+ Parameters
+ ----------
+ lhs : expr-like
+ The left-hand side of the equation, typically a time-updated Function
+ (e.g., `u.forward`).
+ rhs : expr-like, optional
+ The right-hand side of the equation to integrate. Defaults to 0.
+ subdomain : SubDomain, optional
+ A subdomain over which the equation applies.
+ coefficients : dict, optional
+ Optional dictionary of symbolic coefficients for the integration.
+ implicit_dims : tuple, optional
+ Additional dimensions that should be treated implicitly in the equation.
+ **kwargs : dict
+ Additional keyword arguments, such as time integration method selection.
+
+ Notes
+ -----
+ Subclasses must override the `_evaluate()` method to return a sequence
+ of update expressions for each stage in the integration process.
+ """
+
+ def __new__(cls, lhs, rhs, degree=None, source=None, **kwargs):
+ # Normalize to lists first lhs and rhs
+ if not isinstance(lhs, (list, tuple)):
+ lhs = [lhs]
+ if not isinstance(rhs, (list, tuple)):
+ rhs = [rhs]
+
+ # Convert to tuples for immutability
+ lhs_tuple = tuple([i.function for i in lhs])
+ rhs_tuple = tuple(rhs)
+
+ obj = super().__new__(cls, lhs_tuple[0], rhs_tuple[0], **kwargs)
+
+ # Store all equations as immutable tuples
+ obj._eq = tuple(Eq(lhs, rhs) for lhs, rhs in zip(lhs_tuple, rhs_tuple))
+ obj._lhs = lhs_tuple
+ obj._rhs = rhs_tuple
+ obj._deg = degree
+ # Convert source to tuple of tuples for immutability
+ obj._src = tuple(tuple(item)
+ for item in source) if source is not None else None
+ obj._t = lhs_tuple[0].grid.time_dim
+ obj._dt = obj._t.spacing
+ obj._n_eq = len(lhs_tuple)
+
+ return obj
+
+ @property
+ def eq(self):
+ """Return the full tuple of equations."""
+ return self._eq
+
+ @property
+ def lhs(self):
+ """Return tuple of left-hand sides."""
+ return self._lhs
+
+ @property
+ def rhs(self):
+ """Return tuple of right-hand sides."""
+ return self._rhs
+
+ @property
+ def deg(self):
+ """Return the degree parameter."""
+ return self._deg
+
+ @property
+ def src(self):
+ """Return the source parameter as tuple of tuples (immutable)."""
+ return self._src
+
+ @property
+ def t(self):
+ """Return the time (t) parameter."""
+ return self._t
+
+ @property
+ def dt(self):
+ """Return the time step (dt) parameter."""
+ return self._dt
+
+ @property
+ def n_eq(self):
+ """Return the number of equations."""
+ return self._n_eq
+
+ def _evaluate(self, **kwargs):
+ raise NotImplementedError(
+ f"_evaluate() must be implemented in the subclass {self.__class__.__name__}")
+
+
+class RungeKutta(MultiStage):
+ """
+ Base class for explicit Runge-Kutta (RK) time integration methods defined
+ via a Butcher tableau.
+
+ This class handles the general structure of RK schemes by using
+ the Butcher coefficients (`a`, `b`, `c`) to expand a single equation into
+ a series of intermediate stages followed by a final update. Subclasses
+ must define `a`, `b`, and `c` as class attributes.
+
+ Parameters
+ ----------
+ a : tuple of tuple of float
+ The coefficient matrix representing stage dependencies.
+ b : tuple of float
+ The weights for the final combination step.
+ c : tuple of float
+ The time shifts for each intermediate stage (often the row sums of `a`).
+
+ Attributes
+ ----------
+ a : tuple[tuple[float, ...], ...]
+ Butcher tableau `a` coefficients (stage coupling).
+ b : tuple[float, ...]
+ Butcher tableau `b` coefficients (weights for combining stages).
+ c : tuple[float, ...]
+ Butcher tableau `c` coefficients (stage time positions).
+ s : int
+ Number of stages in the RK method, inferred from `b`.
+ """
+
+ CoeffsBC = tuple[float | np.number, ...]
+ CoeffsA = tuple[CoeffsBC, ...]
+
+ def __init__(self, a: CoeffsA, b: CoeffsBC, c: CoeffsBC, lhs, rhs, **kwargs) -> None:
+ self.a, self.b, self.c = a, b, c
+
+ @property
+ def s(self):
+ return len(self.b)
+
+ def _evaluate(self, **kwargs):
+ """
+ Generate the stage-wise equations for a Runge-Kutta time integration method.
+
+ This method takes a single equation of the form `Eq(u.forward, rhs)` and
+ expands it into a sequence of intermediate stage evaluations and a final
+ update equation according to the Runge-Kutta coefficients `a`, `b`, and `c`.
+
+ Returns
+ -------
+ list of Devito Eq objects
+ A list of SymPy Eq objects representing:
+ - `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
+ - 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
+ """
+
+ sregistry = kwargs.get('sregistry')
+ # Create temporary Arrays to hold each stage
+ k = [[Array(name=f'{sregistry.make_name(prefix='k')}', dimensions=self.lhs[j].grid.dimensions, grid=self.lhs[j].grid, dtype=self.lhs[j].dtype) for i in range(self.s)]
+ for j in range(self.n_eq)]
+
+ stage_eqs = []
+
+ # Build each stage
+ for i in range(self.s):
+ u_temp = [self.lhs[l] + self.dt * sum(aij * kj for aij, kj in zip(
+ self.a[i][:i], k[l][:i])) for l in range(self.n_eq)]
+ t_shift = self.t + self.c[i]
+
+ # Evaluate RHS at intermediate value
+ stage_rhs = [uxreplace(self.rhs[l], {**{self.lhs[m]: u_temp[m] for m in range(
+ self.n_eq)}, self.t: t_shift}) for l in range(self.n_eq)]
+ stage_eqs.extend([Eq(k[l][i], stage_rhs[l])
+ for l in range(self.n_eq)])
+
+ # Final update: u.forward = u + dt * sum(b_i * k_i)
+ u_next = [self.lhs[l] + self.dt *
+ sum(bi * ki for bi, ki in zip(self.b, k[l])) for l in range(self.n_eq)]
+ stage_eqs.extend([Eq(self.lhs[l].forward, u_next[l])
+ for l in range(self.n_eq)])
+
+ return stage_eqs
+
+
+@register_method(aliases=['RK44'])
+class RungeKutta44(RungeKutta):
+ """
+ Classic 4th-order Runge-Kutta (RK4) time integration method.
+
+ This class implements the classic explicit Runge-Kutta method of order 4 (RK44).
+
+ Attributes
+ ----------
+ a : tuple[tuple[float, ...], ...]
+ Coefficients of the `a` matrix for intermediate stage coupling.
+ b : tuple[float, ...]
+ Weights for final combination.
+ c : tuple[float, ...]
+ Time positions of intermediate stages.
+ """
+ a = ((0, 0, 0, 0),
+ (1/2, 0, 0, 0),
+ (0, 1/2, 0, 0),
+ (0, 0, 1, 0))
+ b = (1/6, 1/3, 1/3, 1/6)
+ c = (0, 1/2, 1/2, 1)
+
+ def __init__(self, lhs, rhs, **kwargs):
+ super().__init__(a=self.a, b=self.b, c=self.c, lhs=lhs, rhs=rhs, **kwargs)
+
+
+@register_method(aliases=['RK32'])
+class RungeKutta32(RungeKutta):
+ """
+ 3 stages 2nd-order Runge-Kutta (RK32) time integration method.
+
+ This class implements the 3-stages explicit Runge-Kutta method of order 2 (RK32).
+
+ Attributes
+ ----------
+ a : list[list[float]]
+ Coefficients of the `a` matrix for intermediate stage coupling.
+ b : list[float]
+ Weights for final combination.
+ c : list[float]
+ Time positions of intermediate stages.
+ """
+ a = ((0, 0, 0),
+ (1/2, 0, 0),
+ (0, 1/2, 0))
+ b = (0, 0, 1)
+ c = (0, 1/2, 1/2)
+
+ def __init__(self, lhs, rhs, **kwargs):
+ super().__init__(a=self.a, b=self.b, c=self.c, lhs=lhs, rhs=rhs, **kwargs)
+
+
+@register_method(aliases=['RK97'])
+class RungeKutta97(RungeKutta):
+ """
+ 9 stages 7th-order Runge-Kutta (RK97) time integration method.
+
+ This class implements the 9-stages explicit Runge-Kutta method of order 7 (RK97).
+
+ Attributes
+ ----------
+ a : list[list[float]]
+ Coefficients of the `a` matrix for intermediate stage coupling.
+ b : list[float]
+ Weights for final combination.
+ c : list[float]
+ Time positions of intermediate stages.
+ """
+ a = ((0, 0, 0, 0, 0, 0, 0, 0, 0),
+ (4/63, 0, 0, 0, 0, 0, 0, 0, 0),
+ (1/42, 1/14, 0, 0, 0, 0, 0, 0, 0),
+ (1/28, 0, 3/28, 0, 0, 0, 0, 0, 0),
+ (12551/19652, 0, -48363/19652, 10976/4913, 0, 0, 0, 0, 0),
+ (-36616931/27869184, 0, 2370277/442368, -255519173 /
+ 63700992, 226798819/445906944, 0, 0, 0, 0),
+ (-10401401/7164612, 0, 47383/8748, -4914455 /
+ 1318761, -1498465/7302393, 2785280/3739203, 0, 0, 0),
+ (181002080831/17500000000, 0, -14827049601/400000000, 23296401527134463/857600000000000,
+ 2937811552328081/949760000000000, -243874470411/69355468750, 2857867601589/3200000000000),
+ (-228380759/19257212, 0, 4828803/113948, -331062132205/10932626912, -12727101935/3720174304,
+ 22627205314560/4940625496417, -268403949/461033608, 3600000000000/19176750553961))
+ b = (95/2366, 0, 0, 3822231133/16579123200, 555164087/2298419200, 1279328256/9538891505,
+ 5963949/25894400, 50000000000/599799373173, 28487/712800)
+ c = (0, 4/63, 2/21, 1/7, 7/17, 13/24, 7/9, 91/100, 1)
+
+ def __init__(self, lhs, rhs, **kwargs):
+ super().__init__(a=self.a, b=self.b, c=self.c, lhs=lhs, rhs=rhs, **kwargs)
+
+
+@register_method(aliases=['HORK_EXP'])
+class HighOrderRungeKuttaExponential(MultiStage):
+ # In construction
+ """
+ n stages Runge-Kutta (HORK) time integration method.
+
+ This class implements the arbitrary high-order explicit Runge-Kutta method.
+
+ Attributes
+ ----------
+ a : list[list[float]]
+ Coefficients of the `a` matrix for intermediate stage coupling.
+ b : list[float]
+ Weights for final combination.
+ c : list[float]
+ Time positions of intermediate stages.
+ """
+
+ def source_derivatives(self, src_index, **kwargs):
+
+ # Compute the base wavelet function
+ f_deriv = [[src[1] for src in self.src]]
+
+ # Compute derivatives up to order p
+ for _ in range(self.deg - 1):
+ f_deriv.append([deriv.diff(self.t) for deriv in f_deriv[-1]])
+
+ f_deriv.reverse()
+ return f_deriv
+
+ def ssprk_alpha(self, mu=1):
+ """
+ Computes the coefficients for the Strong Stability Preserving Runge-Kutta (SSPRK) method.
+
+ Parameters:
+ mu : float
+ Theoretically, it should be the inverse of the CFL condition (typically mu=1 for best performance).
+ In practice, mu=1 works better.
+ degree : int
+ Degree of the polynomial used in the time-stepping scheme.
+
+ Returns:
+ numpy.ndarray
+ Array of SSPRK coefficients.
+ """
+
+ alpha = [0] * self.deg
+ alpha[0] = 1.0 # Initial coefficient
+
+ # recurrence relation to compute the HORK coefficients following the formula in Gottlieb and Gottlieb (2002)
+ for i in range(1, self.deg):
+ alpha[i] = 1 / (mu * (i + 1)) * alpha[i - 1]
+ alpha[1:i] = [1 / (mu * j) * alpha[j - 1] for j in range(1, i)]
+ alpha[0] = 1 - sum(alpha[1:i + 1])
+
+ return alpha
+
+ def source_inclusion(self, current_state, stage_values, e_p, **integration_params):
+ """
+ Include source terms in the time integration step.
+
+ This method applies source term contributions to the right-hand side
+ of the differential equations during time integration, accounting for
+ time derivatives of the source function and expansion coefficients.
+
+ Parameters
+ ----------
+ current_state : list
+ Current state variables (u).
+ stage_values : list
+ Current stage values (k).
+ e_p : list
+ Expansion coefficients for stability control.
+ **integration_params : dict
+ Integration parameters containing 't', 'dt', 'mu', 'src_index',
+ 'src_deriv', 'n_eq'.
+
+ Returns
+ -------
+ tuple
+ (modified_rhs, updated_e_p) - Updated right-hand side
+ equations and modified expansion coefficients.
+ """
+ # Extract integration parameters
+ mu = integration_params['mu']
+ src_index = integration_params['src_index']
+ src_deriv = integration_params['src_deriv']
+ n_eq = integration_params['n_eq']
+
+ # Build base right-hand side by substituting current stage values
+ src_lhs = [uxreplace(self.rhs[i], {current_state[m]: stage_values[m] for m in range(n_eq)})
+ for i in range(n_eq)]
+
+ # Apply source term contributions if sources exist
+ if self.src is not None:
+ p = len(src_deriv)
+
+ # Add source contributions for each derivative order
+ for i in range(p):
+ if e_p[i] != 0:
+ for j, idx in enumerate(src_index):
+ # Add weighted source derivative contribution
+ source_contribution = (self.src[j][0] * src_deriv[i][j].subs({self.t: self.t * self.dt}) * e_p[i])
+ src_lhs[idx] += source_contribution
+
+ # Update expansion coefficients for next stage
+ e_p = [e_p[i] + mu*self.dt*e_p[i + 1] for i in range(p - 1)] + [e_p[-1]]
+
+ return src_lhs, e_p
+
+ def _evaluate(self, **kwargs):
+ """
+ Generate the stage-wise equations for a Runge-Kutta time integration method.
+
+ This method takes a single equation of the form `Eq(u.forward, rhs)` and
+ expands it into a sequence of intermediate stage evaluations and a final
+ update equation according to the Runge-Kutta coefficients `a`, `b`, and `c`.
+
+ Returns
+ -------
+ list of Eq
+ A list of SymPy Eq objects representing:
+ - `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
+ - 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
+ """
+
+ sregistry = kwargs.get('sregistry')
+ # Create a temporary Array for each variable to save the time stages
+ # k = [Array(name=f'{sregistry.make_name(prefix='k')}', dimensions=u[i].grid.dimensions, grid=u[i].grid, dtype=u[i].dtype) for i in range(n_eq)]
+ k = [TimeFunction(name=f'{sregistry.make_name(prefix='k')}', grid=self.lhs[i].grid,
+ space_order=2, time_order=1, dtype=self.lhs[i].dtype) for i in range(self.n_eq)]
+ k_old = [TimeFunction(name=f'{sregistry.make_name(prefix='k')}', grid=self.lhs[i].grid,
+ space_order=2, time_order=1, dtype=self.lhs[i].dtype) for i in range(self.n_eq)]
+
+ # Compute SSPRK coefficients
+ mu = 1
+ alpha = self.ssprk_alpha(mu=mu)
+
+ # Initialize symbolic differentiation for source terms
+ field_map = {val: i for i, val in enumerate(self.lhs)}
+ if self.src is not None:
+ src_index = [field_map[src[2]] for src in self.src]
+ src_deriv = self.source_derivatives(src_index, **kwargs)
+ else:
+ src_index = None
+ src_deriv = None
+ print('src_index:', src_index)
+ print('src_deriv:', src_deriv)
+
+ # Expansion coefficients for stability control
+ e_p = [0] * self.deg
+ eta = 1
+ e_p[-1] = 1 / eta
+
+ stage_eqs = [Eq(ki, ui) for ki, ui in zip(k, self.lhs)]
+ stage_eqs.extend([Eq(lhs_i.forward, lhs_i*alpha[0]) for lhs_i in self.lhs])
+
+ # Prepare integration parameters for source inclusion
+ integration_params = {'mu': mu, 'src_index': src_index,
+ 'src_deriv': src_deriv, 'n_eq': self.n_eq}
+
+ # Build each stage
+ for i in range(1, self.deg - 1):
+ print('e_p:', e_p)
+ stage_eqs.extend([Eq(k_old_j, k_j) for k_old_j, k_j in zip(k_old, k)])
+ src_lhs, e_p = self.source_inclusion(self.lhs, k_old, e_p, **integration_params)
+ stage_eqs.extend([Eq(k_j, k_old_j+mu*self.dt*src_lhs_j) for k_j, k_old_j, src_lhs_j in zip(k, k_old, src_lhs)])
+ stage_eqs.extend([Eq(lhs_j.forward, lhs_j.forward+k_j*alpha[i]) for lhs_j, k_j in zip(self.lhs, k)])
+ print('e_p:', e_p)
+ # Final Runge-Kutta updates
+ stage_eqs.extend([Eq(k_old_j, k_j) for k_old_j, k_j in zip(k_old, k)])
+ src_lhs, e_p = self.source_inclusion(self.lhs, k_old, e_p, **integration_params)
+ stage_eqs.extend([Eq(k_j, k_old_j+mu*self.dt*src_lhs_j) for k_j, k_old_j, src_lhs_j in zip(k, k_old, src_lhs)])
+ print('e_p:', e_p)
+ stage_eqs.extend([Eq(k_old_j, k_j) for k_old_j, k_j in zip(k_old, k)])
+ src_lhs, _ = self.source_inclusion(self.lhs, k_old, e_p, **integration_params)
+ stage_eqs.extend([Eq(k_j, k_old_j+mu*self.dt*src_lhs_j) for k_j, k_old_j, src_lhs_j in zip(k, k_old, src_lhs)])
+
+ # Compute final approximation
+ stage_eqs.extend([Eq(lhs_j.forward, lhs_j.forward+k_j*alpha[self.deg-1]) for lhs_j, k_j in zip(self.lhs, k)])
+
+ for i in stage_eqs:
+ print(i)
+
+ return stage_eqs
+
+method_registry = MappingProxyType(method_registry)
\ No newline at end of file
diff --git a/tests/test_multistage.py b/tests/test_multistage.py
new file mode 100644
index 0000000000..b47c83d012
--- /dev/null
+++ b/tests/test_multistage.py
@@ -0,0 +1,503 @@
+import pytest
+import numpy as np
+import sympy as sym
+import tempfile
+import pickle
+import os
+
+from devito import (Grid, Function, TimeFunction,
+ Derivative, Operator, solve, Eq, configuration)
+from devito.types.multistage import multistage_method, MultiStage
+from devito.ir.support import SymbolRegistry
+from devito.ir.equations import lower_multistage
+
+configuration['log-level'] = 'DEBUG'
+
+
+def grid_parameters(extent=(10, 10), shape=(3, 3)):
+ grid = Grid(origin=(0, 0), extent=extent, shape=shape, dtype=np.float64)
+ x, y = grid.dimensions
+ dt = grid.stepping_dim.spacing
+ t = grid.time_dim
+ dx = extent[0] / (shape[0] - 1)
+ return grid, x, y, dt, t, dx
+
+
+def time_parameters(tn, dx, scale=1, t0=0):
+ t0, tn = 0.0, tn
+ dt0 = scale / dx**2
+ nt = int((tn - t0) / dt0)
+ dt0 = tn / nt
+ return tn, dt0, nt
+
+
+class Test_API:
+
+ @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97'])
+ def test_pickles(self, time_int):
+ # Grid setup
+ grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3))
+
+ # Define wavefield unknowns: u (displacement) and v (velocity)
+ fun_labels = ['u', 'v']
+ u = [TimeFunction(name=name, grid=grid, space_order=2,
+ time_order=1, dtype=np.float64) for name in fun_labels]
+
+ # Source definition
+ src_spatial = Function(name="src_spat", grid=grid,
+ space_order=2, dtype=np.float64)
+ src_spatial.data[1, 1] = 1
+ src_temporal = (1 - 2 * (t * dt - 1)**2)
+
+ # PDE system (2D acoustic)
+ system_eqs_rhs = [u[1] + src_spatial * src_temporal,
+ Derivative(u[0], (x, 2), fd_order=2)
+ + Derivative(u[0], (y, 2), fd_order=2)
+ + src_spatial * src_temporal]
+
+ # Class of the time integration scheme
+ method = multistage_method(u, system_eqs_rhs, time_int)
+
+ with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
+ pickle.dump(method, tmpfile)
+ filename = tmpfile.name
+
+ with open(filename, 'rb') as file:
+ method_saved = pickle.load(file)
+ os.remove(filename)
+
+ assert str(method) == str(
+ method_saved), "Mismatch in PDE after pickling"
+
+ op_orig = Operator(method)
+ op_saved = Operator(method_saved)
+
+ assert str(op_orig) == str(op_saved)
+
+ @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97'])
+ def test_solve(self, time_int):
+ # Grid setup
+ grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3))
+
+ # Define wavefield unknowns: u (displacement) and v (velocity)
+ fun_labels = ['u', 'v']
+ u = [TimeFunction(name=name, grid=grid, space_order=2,
+ time_order=1, dtype=np.float64) for name in fun_labels]
+
+ # Source definition
+ src_spatial = Function(name="src_spat", grid=grid,
+ space_order=2, dtype=np.float64)
+ src_spatial.data[1, 1] = 1
+ src_temporal = (1 - 2 * (t * dt - 1)**2)
+
+ # PDE system (2D acoustic)
+ system_eqs_rhs = [u[1] + src_spatial * src_temporal,
+ Derivative(u[0], (x, 2), fd_order=2)
+ + Derivative(u[0], (y, 2), fd_order=2)
+ + src_spatial * src_temporal]
+
+ # Time integration scheme
+ pdes = [solve(system_eqs_rhs[i] - u[i], u[i], method=time_int)
+ for i in range(2)]
+
+ assert all(isinstance(i, MultiStage)
+ for i in pdes), "Not all elements are instances of MultiStage"
+
+
+class Test_lowering:
+
+ @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97'])
+ def test_object(self, time_int):
+ # Grid setup
+ grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3))
+
+ # Define wavefield unknowns: u (displacement) and v (velocity)
+ fun_labels = ['u', 'v']
+ u = [TimeFunction(name=name, grid=grid, space_order=2,
+ time_order=1, dtype=np.float64) for name in fun_labels]
+
+ # Source definition
+ src_spatial = Function(name="src_spat", grid=grid,
+ space_order=2, dtype=np.float64)
+ src_spatial.data[1, 1] = 1
+ src_temporal = (1 - 2 * (t * dt - 1)**2)
+
+ # PDE system (2D acoustic)
+ system_eqs_rhs = [u[1] + src_spatial * src_temporal,
+ Derivative(u[0], (x, 2), fd_order=2)
+ + Derivative(u[0], (y, 2), fd_order=2)
+ + src_spatial * src_temporal]
+
+ # Class of the time integration scheme
+ pdes = multistage_method(u, system_eqs_rhs, time_int)
+
+ assert isinstance(
+ pdes, MultiStage), "Not all elements are instances of MultiStage"
+
+ @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97'])
+ def test_lower_multistage(self, time_int):
+ # Grid setup
+ grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3))
+
+ # Define wavefield unknowns: u (displacement) and v (velocity)
+ fun_labels = ['u', 'v']
+ u = [TimeFunction(name=name, grid=grid, space_order=2,
+ time_order=1, dtype=np.float64) for name in fun_labels]
+
+ # Source definition
+ src_spatial = Function(name="src_spat", grid=grid,
+ space_order=2, dtype=np.float64)
+ src_spatial.data[1, 1] = 1
+ src_temporal = (1 - 2 * (t * dt - 1)**2)
+
+ # PDE system (2D acoustic)
+ system_eqs_rhs = [u[1] + src_spatial * src_temporal,
+ Derivative(u[0], (x, 2), fd_order=2)
+ + Derivative(u[0], (y, 2), fd_order=2)
+ + src_spatial * src_temporal]
+
+ # Class of the time integration scheme
+ pdes = multistage_method(u, system_eqs_rhs, time_int)
+
+ # Test the lowering process
+ sregistry = SymbolRegistry()
+
+ # Lower the multistage method - this should not raise an exception
+ lowered_eqs = lower_multistage(pdes, sregistry=sregistry)
+
+ # Validate the lowered equations
+ assert lowered_eqs is not None, "Lowering returned None"
+ assert len(lowered_eqs) > 0, "Lowering returned empty list"
+
+
+class Test_RK:
+
+ @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97'])
+ def test_single_equation_integration(self, time_int):
+ """
+ Test single equation time integration with MultiStage methods.
+
+ This test verifies that time integration works correctly for the simplest case:
+ a single PDE with a single unknown function. This represents the most basic
+ MultiStage usage scenario (e.g., heat equation, simple wave equation).
+ """
+
+ # Grid setup
+ grid, x, y, dt, t, dx = grid_parameters(
+ extent=(1, 1), shape=(200, 200))
+
+ # Define single unknown function
+ u_multi_stage = TimeFunction(name='u_multi_stage', grid=grid, space_order=2,
+ time_order=1, dtype=np.float64)
+
+ # Source definition
+ src_spatial = Function(name="src_spat", grid=grid,
+ space_order=2, dtype=np.float64)
+ src_spatial.data[1, 1] = 1
+ src_temporal = (1 - 2 * (t * dt - 1)**2)
+
+ # Single PDE: du/dt = ∇²u + source (diffusion/wave equation)
+ eq_rhs = (Derivative(u_multi_stage, (x, 2), fd_order=2)
+ + Derivative(u_multi_stage, (y, 2), fd_order=2)
+ + src_spatial * src_temporal)
+
+ # Store initial data for comparison
+ initial_data = u_multi_stage.data.copy()
+
+ # Time integration scheme - single equation MultiStage object
+ pde = multistage_method(u_multi_stage, eq_rhs, time_int)
+
+ # Run the operator
+ op = Operator([pde], subs=grid.spacing_map) # Operator expects a list
+ op(dt=0.01, time=1)
+
+ # Verify that computation actually occurred (data changed)
+ assert not np.array_equal(
+ u_multi_stage.data, initial_data), "Data should have changed"
+
+ @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97'])
+ def test_decoupled_equations(self, time_int):
+ """
+ Test decoupled time integration where each equation gets its own MultiStage object.
+
+ This test verifies that time integration works when creating separate MultiStage
+ objects for each equation, as opposed to coupled integration where all equations
+ are handled by a single MultiStage object.
+ """
+ # Grid setup
+ grid, x, y, dt, t, dx = grid_parameters(
+ extent=(1, 1), shape=(200, 200))
+
+ # Define wavefield unknowns: u (displacement) and v (velocity)
+ fun_labels = ['u_multi_stage', 'v_multi_stage']
+ u_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, time_order=1, dtype=np.float64)
+ for name in fun_labels]
+
+ # Source definition
+ src_spatial = Function(name="src_spat", grid=grid,
+ space_order=2, dtype=np.float64)
+ src_spatial.data[1, 1] = 1
+ src_temporal = (1 - 2 * (t * dt - 1)**2)
+
+ # PDE system - each equation independent for decoupled integration
+ system_eqs_rhs = [u_multi_stage[1] + src_spatial * src_temporal,
+ Derivative(u_multi_stage[0], (x, 2), fd_order=2)
+ + Derivative(u_multi_stage[0], (y, 2), fd_order=2)
+ + src_spatial * src_temporal]
+
+ # Store initial data for comparison
+ initial_data = [u.data.copy() for u in u_multi_stage]
+
+ # Time integration scheme - create separate MultiStage objects (decoupled)
+ pdes = [multistage_method(u_multi_stage[i], system_eqs_rhs[i], time_int)
+ for i in range(len(fun_labels))]
+
+ # Run the operator
+ op = Operator(pdes, subs=grid.spacing_map)
+ op(dt=0.01, time=1)
+
+ # Verify that computation actually occurred (data changed)
+ for i, u in enumerate(u_multi_stage):
+ assert not np.array_equal(
+ u.data, initial_data[i]), f"Data should have changed for variable {i}"
+
+ @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97'])
+ def test_coupled_op_computing(self, time_int):
+ """
+ Test coupled time integration where all equations are handled by a single MultiStage object.
+
+ This test verifies that time integration works correctly when multiple coupled equations
+ are integrated together within a single MultiStage object, allowing for proper coupling
+ between the equations during the time stepping process.
+ """
+ # Grid setup
+ grid, x, y, dt, t, dx = grid_parameters(
+ extent=(1, 1), shape=(200, 200))
+
+ # Define wavefield unknowns: u (displacement) and v (velocity)
+ fun_labels = ['u_multi_stage', 'v_multi_stage']
+ u_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, time_order=1,
+ dtype=np.float64) for name in fun_labels]
+
+ # Source definition
+ src_spatial = Function(name="src_spat", grid=grid,
+ space_order=2, dtype=np.float64)
+ src_spatial.data[1, 1] = 1
+ src_temporal = (1 - 2 * (t * dt - 1)**2)
+
+ # PDE system - coupled acoustic wave equations
+ system_eqs_rhs = [u_multi_stage[1], # velocity equation: du/dt = v
+ Derivative(u_multi_stage[0], (x, 2), fd_order=2)
+ + Derivative(u_multi_stage[0], (y, 2), fd_order=2)
+ + src_spatial * src_temporal] # displacement equation: dv/dt = ∇²u + source
+
+ # Store initial data for comparison
+ initial_data = [u.data.copy() for u in u_multi_stage]
+
+ # Time integration scheme - single coupled MultiStage object
+ pdes = multistage_method(u_multi_stage, system_eqs_rhs, time_int)
+
+ # Run the operator
+ op = Operator(pdes, subs=grid.spacing_map)
+ op(dt=0.01, time=1)
+
+ # Verify that computation actually occurred (data changed)
+ for i, u in enumerate(u_multi_stage):
+ assert not np.array_equal(
+ u.data, initial_data[i]), f"Data should have changed for variable {i}"
+
+ @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97'])
+ def test_low_order_convergence_ODE(self, time_int):
+ # Grid setup
+ grid, x, y, dt, t, dx = grid_parameters(extent=(10, 10), shape=(3, 3))
+
+ # Source definition
+ src_spatial = Function(name="src_spat", grid=grid,
+ space_order=2, dtype=np.float64)
+ src_spatial.data[:] = 1
+ src_temporal = 2 * t * dt
+
+ # Time axis
+ tn, dt0, nt = time_parameters(3.0, dx, scale=1e-2)
+
+ # Time integrator solution
+ # Define wavefield unknowns: u (displacement) and v (velocity)
+ fun_labels = ['u', 'v']
+ u_multi_stage = [TimeFunction(name=name + '_multi_stage', grid=grid, space_order=2, time_order=1,
+ dtype=np.float64) for name in fun_labels]
+
+ # PDE (2D acoustic)
+ eq_rhs = [
+ (-1.5 * u_multi_stage[0] + 0.5 * u_multi_stage[1]) * src_spatial * src_temporal,
+ (-1.5 * u_multi_stage[1] + 0.5 * u_multi_stage[0]) * src_spatial * src_temporal]
+ u_multi_stage[0].data[0, :] = 1
+
+ # Time integration scheme
+ pdes = multistage_method(u_multi_stage, eq_rhs, time_int)
+ op = Operator(pdes, subs=grid.spacing_map)
+ op(dt=dt0, time=nt)
+
+ # exact solution
+ d = np.array([-1, -2])
+ a = np.array([[1, 1], [1, -1]])
+ exact_sol = np.dot(
+ np.dot(a, np.diag(np.exp(d * tn**2))), np.linalg.inv(a))
+ assert np.max(np.abs(exact_sol[0, 0] - u_multi_stage[0].data[0, :])
+ ) < 10 ** -5, "the method is not converging to the solution"
+
+ @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97'])
+ def test_low_order_convergence(self, time_int):
+ # Grid setup
+ grid, x, y, dt, t, dx = grid_parameters(
+ extent=(1000, 1000), shape=(201, 201))
+
+ # Medium velocity model
+ vel = Function(name=f"vel_{time_int}",
+ grid=grid, space_order=2, dtype=np.float64)
+ vel.data[:] = 1.0
+ vel.data[150:, :] = 1.3
+
+ # Source definition
+ src_spatial = Function(
+ name=f"src_spat_{time_int}", grid=grid, space_order=2, dtype=np.float64)
+ src_spatial.data[100, 100] = 1 / dx**2
+ f0 = 0.01
+ src_temporal = (1 - 2 * (np.pi * f0 * (t * dt - 1 / f0))**2) * \
+ sym.exp(-(np.pi * f0 * (t * dt - 1 / f0))**2)
+
+ # Time axis
+ tn, dt0, nt = time_parameters(500.0, dx, scale=1e-1 * np.max(vel.data))
+
+ # Time integrator solution
+ # Define wavefield unknowns: u (displacement) and v (velocity)
+ fun_labels = ['u', 'v']
+ u_multi_stage = [
+ TimeFunction(name=f"{name}_multi_stage_{time_int}", grid=grid, space_order=2, time_order=1,
+ dtype=np.float64) for name in fun_labels]
+
+ # PDE (2D acoustic)
+ eq_rhs = [u_multi_stage[1], (Derivative(u_multi_stage[0], (x, 2), fd_order=2)
+ + Derivative(u_multi_stage[0], (y, 2), fd_order=2)
+ + src_spatial * src_temporal) * vel**2]
+
+ # Time integration scheme
+ pdes = multistage_method(u_multi_stage, eq_rhs, time_int)
+ op = Operator(pdes, subs=grid.spacing_map)
+ op(dt=dt0, time=nt)
+
+ # Devito's default solution
+ u = [TimeFunction(name=f"{name}_{time_int}", grid=grid, space_order=2,
+ time_order=1, dtype=np.float64) for name in fun_labels]
+
+ # PDE (2D acoustic)
+ eq_rhs = [u[1], (Derivative(u[0], (x, 2), fd_order=2) + Derivative(u[0], (y, 2), fd_order=2)
+ + src_spatial * src_temporal) * vel**2]
+
+ # Time integration scheme
+ pdes = [Eq(u[i].forward, solve(Eq(u[i].dt - eq_rhs[i]), u[i].forward))
+ for i in range(len(fun_labels))]
+ op = Operator(pdes, subs=grid.spacing_map)
+ op(dt=dt0, time=nt)
+
+ assert (np.linalg.norm(u[0].data[0, :] - u_multi_stage[0].data[0, :]) / np.linalg.norm(
+ u[0].data[0, :])) < 10**-1, "the method is not converging to the solution"
+
+
+class Test_HORK:
+
+ def test_coupled_op_computing_exp(self, time_int='HORK_EXP'):
+ # Grid setup
+ grid, x, y, dt, t, dx = grid_parameters(
+ extent=(1, 1), shape=(201, 201))
+
+ # Define wavefield unknowns: u (displacement) and v (velocity)
+ fun_labels = ['u_multi_stage', 'v_multi_stage']
+ u_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, time_order=1,
+ dtype=np.float64) for name in fun_labels]
+
+ # Source definition
+ src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=np.float64)
+ src_spatial.data[100, 100] = 1
+ src_temporal = sym.exp(- 100 * (t - 0.01) ** 2)
+
+ # PDE system
+ system_eqs_rhs = [u_multi_stage[1],
+ Derivative(u_multi_stage[0], (x, 2), fd_order=2)
+ + Derivative(u_multi_stage[0], (y, 2), fd_order=2)]
+
+ # Store initial data for comparison
+ initial_data = [u.data.copy() for u in u_multi_stage]
+
+ src = [[src_spatial, src_temporal, u_multi_stage[0]],
+ [src_spatial, src_temporal * 10, u_multi_stage[0]],
+ [src_spatial, src_temporal, u_multi_stage[1]]]
+
+ # Time integration scheme
+ pdes = multistage_method(
+ u_multi_stage, system_eqs_rhs, time_int, degree=4, source=src)
+ op = Operator(pdes, subs=grid.spacing_map)
+ op(dt=0.001, time=2000)
+
+ # Verify that computation actually occurred (data changed)
+ for i, u in enumerate(u_multi_stage):
+ assert not np.array_equal(
+ u.data, initial_data[i]), f"Data should have changed for variable {i}"
+
+
+ @pytest.mark.parametrize('degree', list(range(3, 11)))
+ def test_HORK_EXP_convergence(self, degree):
+ # Grid setup
+ grid, x, y, dt, t, dx = grid_parameters(
+ extent=(1000, 1000), shape=(201, 201))
+
+ # Medium velocity model
+ vel = Function(name="vel", grid=grid, space_order=2, dtype=np.float64)
+ vel.data[:] = 1.0
+ vel.data[150:, :] = 1.3
+
+ # Source definition
+ src_spatial = Function(name="src_spat", grid=grid,
+ space_order=2, dtype=np.float64)
+ src_spatial.data[100, 100] = 1 / dx**2
+ f0 = 0.01
+ src_temporal = (1 - 2 * (np.pi * f0 * (t - 1 / f0))**2) * sym.exp(-(np.pi * f0 * (t - 1 / f0))**2)
+
+ # Time axis
+ tn, dt0, nt = time_parameters(500.0, dx, scale=np.max(vel.data))
+
+ # Time integrator solution
+ # Define wavefield unknowns: u (displacement) and v (velocity)
+ fun_labels = ['u_sol', 'v_sol']
+ u_multi_stage = [TimeFunction(name=name + '_multi_stage', grid=grid, space_order=2, time_order=1,
+ dtype=np.float64) for name in fun_labels]
+
+ # PDE (2D acoustic)
+ eq_rhs = [u_multi_stage[1], (Derivative(u_multi_stage[0],(x,2), fd_order=2) + Derivative(
+ u_multi_stage[0], (y,2), fd_order=2)) * vel**2]
+
+ src = [[src_spatial * vel**2, src_temporal, u_multi_stage[1]]]
+
+ # Time integration scheme
+ pdes = multistage_method(
+ u_multi_stage, eq_rhs, 'HORK_EXP', source=src, degree=degree)
+ op = Operator(pdes, subs=grid.spacing_map)
+ op(dt=dt0, time=nt)
+
+ # Devito's default solution
+ u = [TimeFunction(name=name, grid=grid, space_order=2,
+ time_order=1, dtype=np.float64) for name in fun_labels]
+
+ # PDE (2D acoustic)
+ src_temporal = (1 - 2 * (np.pi * f0 * (t * dt - 1 / f0))**2) * sym.exp(-(np.pi * f0 * (t * dt - 1 / f0))**2)
+ eq_rhs = [u[1], (Derivative(u[0], (x, 2), fd_order=2)
+ + Derivative(u[0], (y, 2), fd_order=2)
+ + src_spatial * src_temporal) * vel**2]
+
+ # Time integration scheme
+ pdes = [Eq(u[i].forward, solve(Eq(u[i].dt - eq_rhs[i]), u[i].forward))
+ for i in range(len(fun_labels))]
+ op = Operator(pdes, subs=grid.spacing_map)
+ op(dt=dt0, time=nt)
+
+ assert (np.linalg.norm(u[0].data[0, :] - u_multi_stage[0].data[0, :]) / np.linalg.norm(
+ u[0].data[0, :])) < 10**-1, "the method is not converging to the solution"
\ No newline at end of file
diff --git a/tests/test_saving_multistage.pkl b/tests/test_saving_multistage.pkl
new file mode 100644
index 0000000000..a88f40675c
Binary files /dev/null and b/tests/test_saving_multistage.pkl differ