diff --git a/MFE_time_size.ipynb b/MFE_time_size.ipynb new file mode 100644 index 0000000000..91982d13da --- /dev/null +++ b/MFE_time_size.ipynb @@ -0,0 +1,637 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f11a9f6b", + "metadata": {}, + "source": [ + "# Multi-Stage Minimum Failure Example (MFE) Time Stepping Simulation\n", + "\n", + "This notebook implements a multi-stage Runge-Kutta-like time integration scheme for solving a PDE system using Devito.\n", + "\n", + "## Overview\n", + "\n", + "To apply multi-stage methods to a PDE system, we first have to formulated as a first-order in time system:\n", + "\n", + "\n", + "$\\frac{d}{dt}\\boldsymbol{U}(\\boldsymbol{x},t)=\\boldsymbol{HU}(\\boldsymbol{x},t)+\\boldsymbol{f}(\\boldsymbol{x},t),\\hspace{3cm}(1)$\n", + "\n", + "where $\\boldsymbol{x}$ is a vector with the spatial variables (x in 1D, (x,y) in 2D, (x,y,z in 3D)), $\\boldsymbol{U}$ is a vector of real-valued functions, $\\boldsymbol{H}$ is a spatial operator that contains spatial derivatives and constant coefficients, and $\\boldsymbol{f}$ is a known vector of real-valued functions.\n", + "\n", + "Expanding $\\boldsymbol{f}(x,y,t)$ in its Taylor 's series, we can approximate the solution of system (1) by the matrix exponential (see [[1](#ref-higham2010)])\n", + "\n", + "$\\hat{\\boldsymbol{U}}(\\boldsymbol{x},t)\\approx[\\boldsymbol{I_p}\\;0]e^{t\\hat{\\boldsymbol{H}}}\\begin{bmatrix}\\boldsymbol{U}(\\boldsymbol{x},t_0)\\\\ \\boldsymbol{e_p}\\end{bmatrix},$\n", + "\n", + "where $\\boldsymbol{e_p}\\in\\mathbb{R}^p$ is the eigenvector with zero in all its entries exept the last one, that equals 1, $\\hat{\\boldsymbol{U}}$ is the vector function $\\boldsymbol{U}$ with $p$ extra zeros at the end, $\\boldsymbol{I_p}$ is the identity matrix of size $p$, and $\\hat{\\boldsymbol{H}}$ is the matrix\n", + "\n", + "$\\hat{\\boldsymbol{H}}=\\begin{bmatrix}\\boldsymbol{H} & \\boldsymbol{W}\\\\ 0 & \\boldsymbol{J_p}\\end{bmatrix}, \\hspace{5.265cm}(2)$\n", + "\n", + "where $\\boldsymbol{J_p}$ is the zero matrix with an upper diagonal of ones\n", + "\n", + "$\\boldsymbol{J_p}=\\begin{bmatrix}0 & 1 & 0 & 0 & \\dots & 0\\\\ 0 & 0 & 1 & 0 & \\dots & 0\\\\ &&\\ddots&&&\\vdots \n", + "\\\\ 0 & 0 & 0 & 0 & \\dots & 1 \\\\0 & 0 & 0 & 0 & \\dots & 0\\end{bmatrix},$\n", + "\n", + "and $\\boldsymbol{W}$ contains the information of the vector function $\\boldsymbol{f}(\\boldsymbol{x},t)$ derivatives\n", + "\n", + "$\\boldsymbol{W}=\\begin{bmatrix}\\frac{\\partial^{p-1}}{\\partial t^{p-1}}\\boldsymbol{f}(\\boldsymbol{x},t_0)\\bigg\\vert \\frac{\\partial^{p-2}}{\\partial t^{p-2}}\\boldsymbol{f}(\\boldsymbol{x},t_0)\\bigg\\vert \\dots \\bigg\\vert \\frac{\\partial}{\\partial t}\\boldsymbol{f}(\\boldsymbol{x},t_0)\\bigg\\vert \\boldsymbol{f}(\\boldsymbol{x},t_0)\\end{bmatrix}.$\n", + "\n", + "Then, we approximate the solution operator in (2) with the m-stage Runge-Kutta (RK) method from [[2](#ref-gottlieb2003)]\n", + "\n", + "\\begin{align*}\n", + " \\boldsymbol{k}_0&=\\boldsymbol{u}_n\\\\\n", + " \\boldsymbol{k}_i&=\\left(\\boldsymbol{I}_{n\\times n}+\\Delta t \\hat{\\boldsymbol{H}}\\right)\\boldsymbol{k}_{i-1},\\quad i=1\\dots m-1\\\\\n", + " \\boldsymbol{k}_m&=\\sum\\limits_{i=0}^{m-2}\\alpha_i\\boldsymbol{k}_i+\\alpha_{m-1}\\left(\\boldsymbol{I}_{n\\times n}+\\Delta t \\hat{\\boldsymbol{H}}\\right)\\boldsymbol{k}_{m-1}\\\\\n", + " \\boldsymbol{u}_{n+1}&=\\boldsymbol{k}_m, \n", + "\\end{align*}\n", + "\n", + "where $\\alpha_i$ are the coefficients of the Runge-Kutta and have a straightforward computation.\n", + "\n", + "So, for each time step we have to construct the matrix $\\hat{\\boldsymbol{H}}$, where we only the submatrix $\\boldsymbol{W}$ change, and apply the RK method to the vector $\\boldsymbol{u}_n=[\\boldsymbol{U}(\\boldsymbol{x},t_n)\\;\\; \\boldsymbol{e_p}]^T$.\n", + "\n", + "So, an outline of the implementation is:\n", + "\n", + "- **Environmental variables**: Set up the grid, and spatial and time variables\n", + "- **PDE System definition**: Define $\\boldsymbol{f}(\\boldsymbol{x},t)$ and the system of equations\n", + "- **Compute the derivatives of the source term**: symbolic computing of $\\boldsymbol{f}(\\boldsymbol{x},t)$ derivatives\n", + "- **Construct the operator $\\hat{\\boldsymbol{H}}$**: application the application of operator $\\hat{\\boldsymbol{H}}$\n", + "- **Implementation of the RK method**: define the required equations of the RK method to pass to Devito's operator. For this particular example, we'll use $m=3$ and $\\alpha_i=1,\\;\\forall i\\in\\{0,1,2\\}$.\n", + "- **Create and run Devito operator**: Executing the operator constructed with the RK equations" + ] + }, + { + "cell_type": "markdown", + "id": "274432cf", + "metadata": {}, + "source": [ + "## 0. Import Required Libraries\n", + "\n", + "First, we import all necessary libraries including NumPy for numerical operations, SymPy for symbolic mathematics, and Devito components for finite difference operations." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "930afc0c", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import sympy as sym\n", + "from devito import (Grid, Function, TimeFunction,\n", + " Derivative, Operator, Eq)\n", + "from devito import configuration\n", + "from devito.symbolics import uxreplace" + ] + }, + { + "cell_type": "markdown", + "id": "3a3b1e0e", + "metadata": {}, + "source": [ + "## 1. Environmental variables\n", + "\n", + "Configure the simulation environment, including logging level, domain parameters, and computational grid setup." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0d2a32c5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Grid created: (201, 201) points over domain (1, 1)\n", + "Spatial dimensions: x=x, y=y\n", + "Time dimension: t=time, dt=dt\n" + ] + } + ], + "source": [ + "# Configure Devito logging\n", + "configuration['log-level'] = 'DEBUG'\n", + "\n", + "# Simulation parameters\n", + "extent = (1, 1) # Physical domain size\n", + "shape = (201, 201) # Grid resolution\n", + "origin = (0, 0) # Domain origin\n", + "\n", + "# Create computational grid\n", + "grid = Grid(origin=origin, extent=extent, shape=shape, dtype=np.float64)\n", + "x, y = grid.dimensions\n", + "t, dt = grid.time_dim, grid.stepping_dim.spacing\n", + "\n", + "print(f\"Grid created: {shape} points over domain {extent}\")\n", + "print(f\"Spatial dimensions: x={x}, y={y}\")\n", + "print(f\"Time dimension: t={t}, dt={dt}\")" + ] + }, + { + "cell_type": "markdown", + "id": "d1d7dfd2", + "metadata": {}, + "source": [ + "## 2. PDE System definition\n", + "\n", + "Create the TimeFunction objects for the wavefield variables (displacement u and velocity v) and define the source terms with both spatial and temporal characteristics." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5c42f929", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Allocating host memory for src_spat(205, 205) [328 KB]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Wavefield functions created:\n", + " u_multi_stage: u_multi_stage(t, x, y)\n", + " v_multi_stage: v_multi_stage(t, x, y)\n", + "\n", + "Source spatial function: src_spat(x, y)\n", + "Source temporal function: exp(-100*(time - 0.01)**2)\n", + "\n", + "PDE system matrix H:\n", + " du_multi_stage/dt = v_multi_stage(t, x, y)\n", + " dv_multi_stage/dt = Derivative(u_multi_stage(t, x, y), (x, 2)) + Derivative(u_multi_stage(t, x, y), (y, 2))\n" + ] + } + ], + "source": [ + "# Define wavefield unknowns: u (displacement) and v (velocity)\n", + "fun_labels = ['u_multi_stage', 'v_multi_stage']\n", + "U_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, time_order=1, dtype=np.float64) for name in fun_labels]\n", + "\n", + "print(\"Wavefield functions created:\")\n", + "for i, u in enumerate(U_multi_stage):\n", + " print(f\" {fun_labels[i]}: {u}\")\n", + "\n", + "# Source definition\n", + "src_spatial = Function(name=\"src_spat\", grid=grid, space_order=2, dtype=np.float64)\n", + "src_spatial.data[100, 100] = 1 # Point source at grid center\n", + "src_temporal = sym.exp(- 100 * (t - 0.01) ** 2) # Gaussian pulse\n", + "\n", + "print(f\"\\nSource spatial function: {src_spatial}\")\n", + "print(f\"Source temporal function: {src_temporal}\")\n", + "\n", + "# PDE right-hand side: du/dt = v, dv/dt = ∇²u\n", + "system_eqs_rhs = [U_multi_stage[1], # du/dt = v\n", + " Derivative(U_multi_stage[0], (x, 2), fd_order=2) +\n", + " Derivative(U_multi_stage[0], (y, 2), fd_order=2)] # dv/dt = ∇²u\n", + "\n", + "# Source coupling: [spatial, temporal, associated variable]\n", + "src = [[src_spatial, src_temporal, U_multi_stage[0]],\n", + " [src_spatial, src_temporal * 10, U_multi_stage[0]],\n", + " [src_spatial, src_temporal, U_multi_stage[1]]]\n", + "\n", + "print(f\"\\nPDE system matrix H:\")\n", + "for i, rhs in enumerate(system_eqs_rhs):\n", + " print(f\" d{fun_labels[i]}/dt = {rhs}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "38f6faa0", + "metadata": {}, + "source": [ + "## 3. Compute the derivatives of the source term\n", + "\n", + "Implement the core helper function that compute time derivatives of source and calculate the source derivatives up to the specified degree (deg=3)." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "046482b7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " - source_derivatives(): Computes time derivatives of source wavelet\n", + "Source index mapping:\n", + " u_multi_stage(t, x, y) -> 0\n", + " v_multi_stage(t, x, y) -> 1\n", + "\n", + "Source indices: [0, 0, 1]\n", + "\n", + "Source derivatives computed up to degree 3\n", + "Number of derivative levels: 3\n", + "Sample derivative expressions:\n", + " Level 2: [(2.0 - 200*time)**2*exp(-100*(time - 0.01)**2) - 200*exp(-100*(time - 0.01)**2), 10*(2.0 - 200*time)**2*exp(-100*(time - 0.01)**2) - 2000*exp(-100*(time - 0.01)**2), (2.0 - 200*time)**2*exp(-100*(time - 0.01)**2) - 200*exp(-100*(time - 0.01)**2)]\n", + " Level 1: [(2.0 - 200*time)*exp(-100*(time - 0.01)**2), 10*(2.0 - 200*time)*exp(-100*(time - 0.01)**2), (2.0 - 200*time)*exp(-100*(time - 0.01)**2)]\n", + " Level 0: [exp(-100*(time - 0.01)**2), 10*exp(-100*(time - 0.01)**2), exp(-100*(time - 0.01)**2)]\n" + ] + } + ], + "source": [ + "def source_derivatives(deg, src, src_index, t):\n", + " \"\"\"\n", + " Compute time derivatives of the source up to given degree.\n", + " \n", + " Parameters:\n", + " -----------\n", + " deg : int\n", + " Degree of derivatives to compute\n", + " src : list\n", + " List of source terms\n", + " src_index : list\n", + " Indices for source terms\n", + " t : symbol\n", + " Time symbol\n", + " \n", + " Returns:\n", + " --------\n", + " f_deriv : list\n", + " List of derivative expressions\n", + " \"\"\"\n", + " f_deriv = [[src[i][1] for i in range(len(src))]]\n", + " \n", + " # Compute derivatives up to order p\n", + " for _ in range(deg - 1):\n", + " f_deriv.append([f_deriv[-1][i].diff(t) for i in range(len(src_index))])\n", + " \n", + " f_deriv.reverse()\n", + " return f_deriv\n", + "\n", + "print(\" - source_derivatives(): Computes time derivatives of source wavelet\")\n", + "\n", + "# Create mapping from wavefield variables to indices\n", + "src_index_map = {val: i for i, val in enumerate(U_multi_stage)}\n", + "print(\"Source index mapping:\")\n", + "for var, idx in src_index_map.items():\n", + " print(f\" {var} -> {idx}\")\n", + "\n", + "# Extract source indices based on associated variables\n", + "src_index = [src_index_map[val] for val in [src[i][2] for i in range(len(src))]]\n", + "print(f\"\\nSource indices: {src_index}\")\n", + "\n", + "# Degree of derivatives to compute\n", + "deg = 3\n", + "\n", + "# Compute source derivatives\n", + "src_deriv = source_derivatives(deg, src, src_index, t)\n", + "print(f\"\\nSource derivatives computed up to degree {deg}\")\n", + "print(f\"Number of derivative levels: {len(src_deriv)}\")\n", + "print(\"Sample derivative expressions:\")\n", + "for i, deriv in enumerate(src_deriv):\n", + " print(f\" Level {deg-1-i}: {deriv}\")" + ] + }, + { + "cell_type": "markdown", + "id": "60c485ae", + "metadata": {}, + "source": [ + "## 4.Construct the operator $\\hat{\\boldsymbol{H}}$\n", + "\n", + " Application of the spatial operator to the vector formed by [u e_p]^T." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ccd0fe2a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " - source_inclusion(): Apply the spatial operator of the PDE system at RK stages\n" + ] + } + ], + "source": [ + "# and handle source term inclusion in the PDE system at each Runge-Kutta stage\n", + "def source_inclusion(rhs, u, k, src_index, src_deriv, e_p, t, dt, n_eq):\n", + " \"\"\"\n", + " Add source terms to the PDE system at each Runge-Kutta stage.\n", + " \n", + " Parameters:\n", + " -----------\n", + " rhs : list\n", + " Right-hand side of PDE system without sources\n", + " u : list \n", + " Wavefield variables\n", + " k : list\n", + " Runge-Kutta stage variables\n", + " src_index : list\n", + " Source indices\n", + " src_deriv : list\n", + " Source derivatives\n", + " e_p : list\n", + " Expansion coefficients of the source term Taylor's series\n", + " t : symbol\n", + " Time symbol\n", + " dt : symbol\n", + " Time step symbol\n", + " n_eq : int\n", + " Number of equations\n", + " \n", + " Returns:\n", + " --------\n", + " src_lhs : list\n", + " Operator application to the vector [u, e_p]^T including source terms\n", + " e_p : list\n", + " Updated expansion coefficients of the source term Taylor's series\n", + " \"\"\"\n", + " # Replace wavefield variables with stage variables\n", + " src_lhs = [uxreplace(rhs[i], {u[m]: k[m] for m in range(n_eq)}) for i in range(n_eq)]\n", + " \n", + " p = len(src_deriv)\n", + " \n", + " # Add source contributions\n", + " for i in range(p):\n", + " if e_p[i] != 0:\n", + " for j in range(len(src_index)):\n", + " src_lhs[src_index[j]] += src[j][0] * src_deriv[i][j].subs({t: t * dt}) * e_p[i]\n", + " \n", + " # Update expansion coefficients of the source term Taylor's series\n", + " e_p = [e_p[i] + dt * e_p[i + 1] for i in range(p - 1)] + [e_p[-1]]\n", + " \n", + " return src_lhs, e_p\n", + "\n", + "print(\" - source_inclusion(): Apply the spatial operator of the PDE system at RK stages\")" + ] + }, + { + "cell_type": "markdown", + "id": "a462c0a9", + "metadata": {}, + "source": [ + "## 5. Implementation of the RK method\n", + "\n", + "Construct the multi-stage time integration scheme with initialization, multiple RK stages, and final update. This implements a toy example of a class of High-Order Runge-Kutta (HORK) methods, with proper source term integration." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0543e6d9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Runge-Kutta temporary variables:\n", + " k: ['k0', 'k1']\n", + " k_old: ['k00', 'k01']\n", + "\n", + "Expansion coefficients initialized of the source term Taylor's series:\n", + " e_p = [0, 0, 1.0]\n", + " eta = 1\n", + "\n", + "SSPRK coefficients initialized:\n", + " alpha = [1, 1, 1, 1]\n", + "Building Runge-Kutta-like scheme...\n", + " Stage 0: Initialization\n", + " Stage 1: First RK stage\n", + " Stage 2: Second RK stage\n", + " Stage 3: Final RK stage and update\n", + "\n", + "Total equations created: 20\n", + "Scheme structure:\n", + " Stage 0: Eq(k0(x, y), u_multi_stage(t, x, y))\n", + " Stage 1: Eq(k1(x, y), v_multi_stage(t, x, y))\n", + " Stage 2: Eq(u_multi_stage(t + dt, x, y), u_multi_stage(t, x, y))\n", + " Stage 3: Eq(v_multi_stage(t + dt, x, y), v_multi_stage(t, x, y))\n", + " Stage 4: Eq(k00(x, y), k0(x, y))\n", + " Stage 5: Eq(k01(x, y), k1(x, y))\n", + " Stage 6: Eq(k0(x, y), dt*(k01(x, y) + 11.0*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2)) + k00(x, y))\n", + " Stage 7: Eq(k1(x, y), dt*(src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + Derivative(k00(x, y), (x, 2)) + Derivative(k00(x, y), (y, 2))) + k01(x, y))\n", + " Stage 8: Eq(u_multi_stage(t + dt, x, y), k0(x, y) + u_multi_stage(t + dt, x, y))\n", + " Stage 9: Eq(v_multi_stage(t + dt, x, y), k1(x, y) + v_multi_stage(t + dt, x, y))\n", + " Stage 10: Eq(k00(x, y), k0(x, y))\n", + " Stage 11: Eq(k01(x, y), k1(x, y))\n", + " Stage 12: Eq(k0(x, y), dt*(11.0*dt*(-200*time*dt + 2.0)*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + k01(x, y) + 11.0*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2)) + k00(x, y))\n", + " Stage 13: Eq(k1(x, y), dt*(1.0*dt*(-200*time*dt + 2.0)*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + Derivative(k00(x, y), (x, 2)) + Derivative(k00(x, y), (y, 2))) + k01(x, y))\n", + " Stage 14: Eq(k00(x, y), k0(x, y))\n", + " Stage 15: Eq(k01(x, y), k1(x, y))\n", + " Stage 16: Eq(k0(x, y), dt*(1.0*dt**2*((-200*time*dt + 2.0)**2*exp(-100*(time*dt - 0.01)**2) - 200*exp(-100*(time*dt - 0.01)**2))*src_spat(x, y) + 1.0*dt**2*(10*(-200*time*dt + 2.0)**2*exp(-100*(time*dt - 0.01)**2) - 2000*exp(-100*(time*dt - 0.01)**2))*src_spat(x, y) + 22.0*dt*(-200*time*dt + 2.0)*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + k01(x, y) + 11.0*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2)) + k00(x, y))\n", + " Stage 17: Eq(k1(x, y), dt*(1.0*dt**2*((-200*time*dt + 2.0)**2*exp(-100*(time*dt - 0.01)**2) - 200*exp(-100*(time*dt - 0.01)**2))*src_spat(x, y) + 2.0*dt*(-200*time*dt + 2.0)*src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + src_spat(x, y)*exp(-100*(time*dt - 0.01)**2) + Derivative(k00(x, y), (x, 2)) + Derivative(k00(x, y), (y, 2))) + k01(x, y))\n", + " Stage 18: Eq(u_multi_stage(t + dt, x, y), k0(x, y) + u_multi_stage(t + dt, x, y))\n", + " Stage 19: Eq(v_multi_stage(t + dt, x, y), k1(x, y) + v_multi_stage(t + dt, x, y))\n" + ] + } + ], + "source": [ + "n_eq = 2 # Number of PDE unknowns (u, v)\n", + "\n", + "# Temporary variables for Runge-Kutta stages\n", + "k = [Function(name=f'k{i}', grid=grid, space_order=2, time_order=1, dtype=U_multi_stage[i].dtype) for i in range(n_eq)]\n", + "# Previous stage variables needed for temporary storage\n", + "k_old = [Function(name=f'k0{i}', grid=grid, space_order=2, time_order=1, dtype=U_multi_stage[i].dtype) for i in range(n_eq)]\n", + "\n", + "print(f\"\\nRunge-Kutta temporary variables:\")\n", + "print(f\" k: {[ki.name for ki in k]}\")\n", + "print(f\" k_old: {[ki.name for ki in k_old]}\")\n", + "\n", + "# Initialize expansion coefficients of the source term Taylor's series\n", + "e_p = [0] * deg\n", + "eta = 1\n", + "e_p[-1] = 1 / eta\n", + "print(f\"\\nExpansion coefficients initialized of the source term Taylor's series:\")\n", + "print(f\" e_p = {e_p}\")\n", + "print(f\" eta = {eta}\")\n", + "\n", + "# Initialize SSPRK coefficients (toy example)\n", + "alpha = [1]*4\n", + "print(f\"\\nSSPRK coefficients initialized:\")\n", + "print(f\" alpha = {alpha}\")\n", + "\n", + "\n", + "# Initialize list to store all stage equations\n", + "stage_eqs = []\n", + "\n", + "print(\"Building Runge-Kutta-like scheme...\")\n", + "\n", + "# Stage 0: Initialization\n", + "print(\" Stage 0: Initialization\")\n", + "stage_eqs.extend([Eq(k[i], U_multi_stage[i]) for i in range(n_eq)])\n", + "[stage_eqs.append(Eq(U_multi_stage[i].forward, U_multi_stage[i] * alpha[0])) for i in range(n_eq)]\n", + "\n", + "# Stage 1\n", + "print(\" Stage 1: First RK stage\")\n", + "[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]\n", + "src_lhs, e_p = source_inclusion(system_eqs_rhs, U_multi_stage, k_old, src_index, src_deriv, e_p, t, dt, n_eq)\n", + "[stage_eqs.append(Eq(k[j], k_old[j] + dt * src_lhs[j])) for j in range(n_eq)]\n", + "[stage_eqs.append(Eq(U_multi_stage[j].forward, U_multi_stage[j].forward + k[j] * alpha[1])) for j in range(n_eq)]\n", + "\n", + "# Stage 2\n", + "print(\" Stage 2: Second RK stage\")\n", + "[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]\n", + "src_lhs, e_p = source_inclusion(system_eqs_rhs, U_multi_stage, k_old, src_index, src_deriv, e_p, t, dt, n_eq)\n", + "[stage_eqs.append(Eq(k[j], k_old[j] + dt * src_lhs[j])) for j in range(n_eq)]\n", + "\n", + "# Stage 3 and final update\n", + "print(\" Stage 3: Final RK stage and update\")\n", + "[stage_eqs.append(Eq(k_old[j], k[j])) for j in range(n_eq)]\n", + "src_lhs, _ = source_inclusion(system_eqs_rhs, U_multi_stage, k_old, src_index, src_deriv, e_p, t, dt, n_eq)\n", + "[stage_eqs.append(Eq(k[j], k_old[j] + dt * src_lhs[j])) for j in range(n_eq)]\n", + "[stage_eqs.append(Eq(U_multi_stage[j].forward, U_multi_stage[j].forward + k[j] * alpha[deg - 1])) for j in range(n_eq)]\n", + "\n", + "print(f\"\\nTotal equations created: {len(stage_eqs)}\")\n", + "print(\"Scheme structure:\")\n", + "for i, stage in enumerate(stage_eqs):\n", + " print(f\" Stage {i}: {stage}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "ef2048a1", + "metadata": {}, + "source": [ + "## 6. Create and Run Devito Operator\n", + "\n", + "Compile all the equations into a Devito Operator and execute the simulation with the specified parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d64897d9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating Devito Operator...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Operator `Kernel` generated in 0.33 s\n", + " * lowering.Clusters: 0.17 s (52.7 %)\n", + " * specializing.Clusters: 0.10 s (31.0 %)\n", + " * lowering.IET: 0.11 s (34.1 %)\n", + "Flops reduction after symbolic optimization: [209 --> 106]\n", + "Operator `Kernel` fetched `/tmp/devito-jitcache-uid1000/ead4e1022d510052a8b9b1fb08d861635f2bdbdd.c` in 0.08 s from jit-cache\n", + "Allocating host memory for k0(205, 205) [328 KB]\n", + "Allocating host memory for k00(205, 205) [328 KB]\n", + "Allocating host memory for k01(205, 205) [328 KB]\n", + "Allocating host memory for k1(205, 205) [328 KB]\n", + "Allocating host memory for u_multi_stage(2, 205, 205) [657 KB]\n", + "Allocating host memory for v_multi_stage(2, 205, 205) [657 KB]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Operator successfully created!\n", + " Number of equations: 20\n", + " Grid spacing substitutions applied: {h_x: np.float64(0.005), h_y: np.float64(0.005)}\n", + "\n", + "Simulation parameters:\n", + " Time step (dt): 0.001\n", + " Maximum time: 2000\n", + "\n", + "Running simulation...\n" + ] + }, + { + "ename": "InvalidArgument", + "evalue": "No value found for parameter time_size", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mInvalidArgument\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[11]\u001b[39m\u001b[32m, line 19\u001b[39m\n\u001b[32m 17\u001b[39m \u001b[38;5;66;03m# Execute the simulation\u001b[39;00m\n\u001b[32m 18\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33mRunning simulation...\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m \u001b[43mop\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdt\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdt_value\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtime\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtime_max\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 21\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mSimulation completed successfully!\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 23\u001b[39m \u001b[38;5;66;03m# Display final wavefield shapes and some statistics\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Trabajo/pos-doc/posdoc_fernando/devito/devito/operator/operator.py:878\u001b[39m, in \u001b[36mOperator.__call__\u001b[39m\u001b[34m(self, **kwargs)\u001b[39m\n\u001b[32m 877\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, **kwargs):\n\u001b[32m--> \u001b[39m\u001b[32m878\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Trabajo/pos-doc/posdoc_fernando/devito/devito/operator/operator.py:995\u001b[39m, in \u001b[36mOperator.apply\u001b[39m\u001b[34m(self, **kwargs)\u001b[39m\n\u001b[32m 993\u001b[39m \u001b[38;5;66;03m# Build the arguments list to invoke the kernel function\u001b[39;00m\n\u001b[32m 994\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m._profiler.timer_on(\u001b[33m'\u001b[39m\u001b[33marguments-preprocess\u001b[39m\u001b[33m'\u001b[39m):\n\u001b[32m--> \u001b[39m\u001b[32m995\u001b[39m args = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43marguments\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 996\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m switch_log_level(comm=args.comm):\n\u001b[32m 997\u001b[39m \u001b[38;5;28mself\u001b[39m._emit_args_profiling(\u001b[33m'\u001b[39m\u001b[33marguments-preprocess\u001b[39m\u001b[33m'\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Trabajo/pos-doc/posdoc_fernando/devito/devito/operator/operator.py:773\u001b[39m, in \u001b[36mOperator.arguments\u001b[39m\u001b[34m(self, **kwargs)\u001b[39m\n\u001b[32m 771\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.parameters:\n\u001b[32m 772\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m args.get(p.name) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m773\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m InvalidArgument(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mNo value found for parameter \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mp.name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 774\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m args\n", + "\u001b[31mInvalidArgument\u001b[39m: No value found for parameter time_size" + ] + } + ], + "source": [ + "# Create the Devito Operator\n", + "print(\"Creating Devito Operator...\")\n", + "op = Operator(stage_eqs, subs=grid.spacing_map)\n", + "\n", + "print(\"Operator successfully created!\")\n", + "print(f\" Number of equations: {len(stage_eqs)}\")\n", + "print(f\" Grid spacing substitutions applied: {grid.spacing_map}\")\n", + "\n", + "# Define simulation parameters\n", + "dt_value = 0.001 # Time step size\n", + "time_max = 2000 # Maximum simulation time\n", + "\n", + "print(f\"\\nSimulation parameters:\")\n", + "print(f\" Time step (dt): {dt_value}\")\n", + "print(f\" Maximum time: {time_max}\")\n", + "\n", + "# Execute the simulation\n", + "print(\"\\nRunning simulation...\")\n", + "op(dt=dt_value, time=time_max)\n", + "\n", + "print(\"Simulation completed successfully!\")\n", + "\n", + "# Display final wavefield shapes and some statistics\n", + "print(f\"\\nFinal wavefield statistics:\")\n", + "for i, u in enumerate(U_multi_stage):\n", + " print(f\" {fun_labels[i]}:\")\n", + " print(f\" Shape: {u.data.shape}\")\n", + " print(f\" Max value: {np.max(u.data):.6f}\")\n", + " print(f\" Min value: {np.min(u.data):.6f}\")\n", + " print(f\" Mean value: {np.mean(u.data):.6f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4e028d49", + "metadata": {}, + "source": [ + "## 7. References\n", + "\n", + "\n", + "\n", + "[1] Al-Mohy AH, Higham NJ (2010) A new scaling and squaring algorithm for\n", + "the matrix exponential. SIAM Journal on Matrix Analysis and Applications\n", + "31(3):970–989\n", + "\n", + "\n", + "[2] Gottlieb S, Gottlieb LAJ (2003) Strong stability preserving properties of runge–kutta time discretization methods for linear constant coefficient operators. Journal of Scientific Computing 18(1):83–109" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "devito_b", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/devito/ir/equations/algorithms.py b/devito/ir/equations/algorithms.py index ce844887aa..66b44e34f9 100644 --- a/devito/ir/equations/algorithms.py +++ b/devito/ir/equations/algorithms.py @@ -6,14 +6,15 @@ from devito.tools import (Ordering, as_tuple, flatten, filter_sorted, filter_ordered, frozendict) from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension, - ConditionalDimension) + ConditionalDimension, MultiStage) from devito.types.array import Array from devito.types.basic import AbstractFunction from devito.types.dimension import MultiSubDimension, Thickness from devito.data.allocators import DataReference from devito.logger import warning -__all__ = ['dimension_sort', 'lower_exprs', 'concretize_subdims'] + +__all__ = ['dimension_sort', 'lower_multistage', 'lower_exprs', 'concretize_subdims'] def dimension_sort(expr): @@ -95,6 +96,39 @@ def handle_indexed(indexed): return ordering +def lower_multistage(expressions, **kwargs): + """ + Separating the multi-stage time-integrator scheme in stages: + * If the object is MultiStage, it creates the stages of the method. + """ + return _lower_multistage(expressions, **kwargs) + + +@singledispatch +def _lower_multistage(expr, **kwargs): + """ + Default handler for expressions that are not MultiStage. + Simply return them in a list. + """ + return [expr] + + +@_lower_multistage.register(MultiStage) +def _(expr, **kwargs): + """ + Specialized handler for MultiStage expressions. + """ + return expr._evaluate(**kwargs) + + +@_lower_multistage.register(Iterable) +def _(exprs, **kwargs): + """ + Handle iterables of expressions. + """ + return sum([_lower_multistage(expr, **kwargs) for expr in exprs], []) + + def lower_exprs(expressions, subs=None, **kwargs): """ Lowering an expression consists of the following passes: diff --git a/devito/operations/solve.py b/devito/operations/solve.py index 0203dbe26d..498ad376f9 100644 --- a/devito/operations/solve.py +++ b/devito/operations/solve.py @@ -7,6 +7,8 @@ from devito.finite_differences.derivative import Derivative from devito.tools import as_tuple +from devito.types.multistage import resolve_method + __all__ = ['solve', 'linsolve'] @@ -56,9 +58,12 @@ def solve(eq, target, **kwargs): # We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions if len(sols) > 1: - return target.new_from_mat(sols) + sols_temp = target.new_from_mat(sols) else: - return sols[0] + sols_temp = sols[0] + + method = kwargs.get("method", None) + return sols_temp if method is None else resolve_method(method)(target, sols_temp) def linsolve(expr, target, **kwargs): diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 0d473fe6a2..7f5b769a77 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -17,7 +17,7 @@ InvalidOperator) from devito.logger import (debug, info, perf, warning, is_log_enabled_for, switch_log_level) -from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims +from devito.ir.equations import LoweredEq, lower_multistage, lower_exprs, concretize_subdims from devito.ir.clusters import ClusterGroup, clusterize from devito.ir.iet import (Callable, CInterface, EntryFunction, DeviceFunction, FindSymbols, MetaCall, derive_parameters, iet_build) @@ -40,7 +40,6 @@ disk_layer) from devito.types.dimension import Thickness - __all__ = ['Operator'] @@ -337,6 +336,8 @@ def _lower_exprs(cls, expressions, **kwargs): * Apply substitution rules; * Shift indices for domain alignment. """ + expressions = lower_multistage(expressions, **kwargs) + expand = kwargs['options'].get('expand', True) # Specialization is performed on unevaluated expressions diff --git a/devito/types/__init__.py b/devito/types/__init__.py index 6ec8bdfd16..a8c1d3b224 100644 --- a/devito/types/__init__.py +++ b/devito/types/__init__.py @@ -22,3 +22,5 @@ from .relational import * # noqa from .sparse import * # noqa from .tensor import * # noqa + +from .multistage import * # noqa diff --git a/devito/types/multistage.py b/devito/types/multistage.py new file mode 100644 index 0000000000..55ab8f9826 --- /dev/null +++ b/devito/types/multistage.py @@ -0,0 +1,540 @@ +from devito.types.equation import Eq +from devito.types.dense import Function, TimeFunction +from devito.symbolics import uxreplace +import numpy as np +from devito.types.array import Array +from types import MappingProxyType + +method_registry = {} + + +def register_method(cls=None, *, aliases=None): + """ + Register a time integration method class. + + Parameters + ---------- + cls : class, optional + The method class to register. + aliases : list of str, optional + Additional aliases for the method. + """ + def decorator(cls): + # Register the class name + method_registry[cls.__name__] = cls + + # Register any aliases + if aliases: + for alias in aliases: + method_registry[alias] = cls + + return cls + + if cls is None: + # Called as @register_method(aliases=['alias1']) + return decorator + else: + # Called as @register_method + return decorator(cls) + + +def resolve_method(method): + """ + Resolve a time integration method by name. + + Parameters + ---------- + method : str + Name or alias of the time integration method. + + Returns + ------- + class + The method class. + + Raises + ------ + ValueError + If the method is not found in the registry. + """ + try: + return method_registry[method] + except KeyError: + available = sorted(method_registry.keys()) + raise ValueError( + f"The time integrator '{method}' is not implemented. " + f"Available methods: {available}" + ) + + +def multistage_method(lhs, rhs, method, degree=None, source=None): + method_cls = resolve_method(method) + return method_cls(lhs, rhs, degree=degree, source=source) + + +class MultiStage(Eq): + """ + Abstract base class for multi-stage time integration methods + (e.g., Runge-Kutta schemes) in Devito. + + This class represents a symbolic equation of the form `target = rhs` + and provides a mechanism to associate it with a time integration + scheme. The specific integration behavior must be implemented by + subclasses via the `_evaluate` method. + + Parameters + ---------- + lhs : expr-like + The left-hand side of the equation, typically a time-updated Function + (e.g., `u.forward`). + rhs : expr-like, optional + The right-hand side of the equation to integrate. Defaults to 0. + subdomain : SubDomain, optional + A subdomain over which the equation applies. + coefficients : dict, optional + Optional dictionary of symbolic coefficients for the integration. + implicit_dims : tuple, optional + Additional dimensions that should be treated implicitly in the equation. + **kwargs : dict + Additional keyword arguments, such as time integration method selection. + + Notes + ----- + Subclasses must override the `_evaluate()` method to return a sequence + of update expressions for each stage in the integration process. + """ + + def __new__(cls, lhs, rhs, degree=None, source=None, **kwargs): + # Normalize to lists first lhs and rhs + if not isinstance(lhs, (list, tuple)): + lhs = [lhs] + if not isinstance(rhs, (list, tuple)): + rhs = [rhs] + + # Convert to tuples for immutability + lhs_tuple = tuple([i.function for i in lhs]) + rhs_tuple = tuple(rhs) + + obj = super().__new__(cls, lhs_tuple[0], rhs_tuple[0], **kwargs) + + # Store all equations as immutable tuples + obj._eq = tuple(Eq(lhs, rhs) for lhs, rhs in zip(lhs_tuple, rhs_tuple)) + obj._lhs = lhs_tuple + obj._rhs = rhs_tuple + obj._deg = degree + # Convert source to tuple of tuples for immutability + obj._src = tuple(tuple(item) + for item in source) if source is not None else None + obj._t = lhs_tuple[0].grid.time_dim + obj._dt = obj._t.spacing + obj._n_eq = len(lhs_tuple) + + return obj + + @property + def eq(self): + """Return the full tuple of equations.""" + return self._eq + + @property + def lhs(self): + """Return tuple of left-hand sides.""" + return self._lhs + + @property + def rhs(self): + """Return tuple of right-hand sides.""" + return self._rhs + + @property + def deg(self): + """Return the degree parameter.""" + return self._deg + + @property + def src(self): + """Return the source parameter as tuple of tuples (immutable).""" + return self._src + + @property + def t(self): + """Return the time (t) parameter.""" + return self._t + + @property + def dt(self): + """Return the time step (dt) parameter.""" + return self._dt + + @property + def n_eq(self): + """Return the number of equations.""" + return self._n_eq + + def _evaluate(self, **kwargs): + raise NotImplementedError( + f"_evaluate() must be implemented in the subclass {self.__class__.__name__}") + + +class RungeKutta(MultiStage): + """ + Base class for explicit Runge-Kutta (RK) time integration methods defined + via a Butcher tableau. + + This class handles the general structure of RK schemes by using + the Butcher coefficients (`a`, `b`, `c`) to expand a single equation into + a series of intermediate stages followed by a final update. Subclasses + must define `a`, `b`, and `c` as class attributes. + + Parameters + ---------- + a : tuple of tuple of float + The coefficient matrix representing stage dependencies. + b : tuple of float + The weights for the final combination step. + c : tuple of float + The time shifts for each intermediate stage (often the row sums of `a`). + + Attributes + ---------- + a : tuple[tuple[float, ...], ...] + Butcher tableau `a` coefficients (stage coupling). + b : tuple[float, ...] + Butcher tableau `b` coefficients (weights for combining stages). + c : tuple[float, ...] + Butcher tableau `c` coefficients (stage time positions). + s : int + Number of stages in the RK method, inferred from `b`. + """ + + CoeffsBC = tuple[float | np.number, ...] + CoeffsA = tuple[CoeffsBC, ...] + + def __init__(self, a: CoeffsA, b: CoeffsBC, c: CoeffsBC, lhs, rhs, **kwargs) -> None: + self.a, self.b, self.c = a, b, c + + @property + def s(self): + return len(self.b) + + def _evaluate(self, **kwargs): + """ + Generate the stage-wise equations for a Runge-Kutta time integration method. + + This method takes a single equation of the form `Eq(u.forward, rhs)` and + expands it into a sequence of intermediate stage evaluations and a final + update equation according to the Runge-Kutta coefficients `a`, `b`, and `c`. + + Returns + ------- + list of Devito Eq objects + A list of SymPy Eq objects representing: + - `s` stage equations of the form `k_i = rhs evaluated at intermediate state` + - 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` + """ + + sregistry = kwargs.get('sregistry') + # Create temporary Arrays to hold each stage + k = [[Array(name=f'{sregistry.make_name(prefix='k')}', dimensions=self.lhs[j].grid.dimensions, grid=self.lhs[j].grid, dtype=self.lhs[j].dtype) for i in range(self.s)] + for j in range(self.n_eq)] + + stage_eqs = [] + + # Build each stage + for i in range(self.s): + u_temp = [self.lhs[l] + self.dt * sum(aij * kj for aij, kj in zip( + self.a[i][:i], k[l][:i])) for l in range(self.n_eq)] + t_shift = self.t + self.c[i] + + # Evaluate RHS at intermediate value + stage_rhs = [uxreplace(self.rhs[l], {**{self.lhs[m]: u_temp[m] for m in range( + self.n_eq)}, self.t: t_shift}) for l in range(self.n_eq)] + stage_eqs.extend([Eq(k[l][i], stage_rhs[l]) + for l in range(self.n_eq)]) + + # Final update: u.forward = u + dt * sum(b_i * k_i) + u_next = [self.lhs[l] + self.dt * + sum(bi * ki for bi, ki in zip(self.b, k[l])) for l in range(self.n_eq)] + stage_eqs.extend([Eq(self.lhs[l].forward, u_next[l]) + for l in range(self.n_eq)]) + + return stage_eqs + + +@register_method(aliases=['RK44']) +class RungeKutta44(RungeKutta): + """ + Classic 4th-order Runge-Kutta (RK4) time integration method. + + This class implements the classic explicit Runge-Kutta method of order 4 (RK44). + + Attributes + ---------- + a : tuple[tuple[float, ...], ...] + Coefficients of the `a` matrix for intermediate stage coupling. + b : tuple[float, ...] + Weights for final combination. + c : tuple[float, ...] + Time positions of intermediate stages. + """ + a = ((0, 0, 0, 0), + (1/2, 0, 0, 0), + (0, 1/2, 0, 0), + (0, 0, 1, 0)) + b = (1/6, 1/3, 1/3, 1/6) + c = (0, 1/2, 1/2, 1) + + def __init__(self, lhs, rhs, **kwargs): + super().__init__(a=self.a, b=self.b, c=self.c, lhs=lhs, rhs=rhs, **kwargs) + + +@register_method(aliases=['RK32']) +class RungeKutta32(RungeKutta): + """ + 3 stages 2nd-order Runge-Kutta (RK32) time integration method. + + This class implements the 3-stages explicit Runge-Kutta method of order 2 (RK32). + + Attributes + ---------- + a : list[list[float]] + Coefficients of the `a` matrix for intermediate stage coupling. + b : list[float] + Weights for final combination. + c : list[float] + Time positions of intermediate stages. + """ + a = ((0, 0, 0), + (1/2, 0, 0), + (0, 1/2, 0)) + b = (0, 0, 1) + c = (0, 1/2, 1/2) + + def __init__(self, lhs, rhs, **kwargs): + super().__init__(a=self.a, b=self.b, c=self.c, lhs=lhs, rhs=rhs, **kwargs) + + +@register_method(aliases=['RK97']) +class RungeKutta97(RungeKutta): + """ + 9 stages 7th-order Runge-Kutta (RK97) time integration method. + + This class implements the 9-stages explicit Runge-Kutta method of order 7 (RK97). + + Attributes + ---------- + a : list[list[float]] + Coefficients of the `a` matrix for intermediate stage coupling. + b : list[float] + Weights for final combination. + c : list[float] + Time positions of intermediate stages. + """ + a = ((0, 0, 0, 0, 0, 0, 0, 0, 0), + (4/63, 0, 0, 0, 0, 0, 0, 0, 0), + (1/42, 1/14, 0, 0, 0, 0, 0, 0, 0), + (1/28, 0, 3/28, 0, 0, 0, 0, 0, 0), + (12551/19652, 0, -48363/19652, 10976/4913, 0, 0, 0, 0, 0), + (-36616931/27869184, 0, 2370277/442368, -255519173 / + 63700992, 226798819/445906944, 0, 0, 0, 0), + (-10401401/7164612, 0, 47383/8748, -4914455 / + 1318761, -1498465/7302393, 2785280/3739203, 0, 0, 0), + (181002080831/17500000000, 0, -14827049601/400000000, 23296401527134463/857600000000000, + 2937811552328081/949760000000000, -243874470411/69355468750, 2857867601589/3200000000000), + (-228380759/19257212, 0, 4828803/113948, -331062132205/10932626912, -12727101935/3720174304, + 22627205314560/4940625496417, -268403949/461033608, 3600000000000/19176750553961)) + b = (95/2366, 0, 0, 3822231133/16579123200, 555164087/2298419200, 1279328256/9538891505, + 5963949/25894400, 50000000000/599799373173, 28487/712800) + c = (0, 4/63, 2/21, 1/7, 7/17, 13/24, 7/9, 91/100, 1) + + def __init__(self, lhs, rhs, **kwargs): + super().__init__(a=self.a, b=self.b, c=self.c, lhs=lhs, rhs=rhs, **kwargs) + + +@register_method(aliases=['HORK_EXP']) +class HighOrderRungeKuttaExponential(MultiStage): + # In construction + """ + n stages Runge-Kutta (HORK) time integration method. + + This class implements the arbitrary high-order explicit Runge-Kutta method. + + Attributes + ---------- + a : list[list[float]] + Coefficients of the `a` matrix for intermediate stage coupling. + b : list[float] + Weights for final combination. + c : list[float] + Time positions of intermediate stages. + """ + + def source_derivatives(self, src_index, **kwargs): + + # Compute the base wavelet function + f_deriv = [[src[1] for src in self.src]] + + # Compute derivatives up to order p + for _ in range(self.deg - 1): + f_deriv.append([deriv.diff(self.t) for deriv in f_deriv[-1]]) + + f_deriv.reverse() + return f_deriv + + def ssprk_alpha(self, mu=1): + """ + Computes the coefficients for the Strong Stability Preserving Runge-Kutta (SSPRK) method. + + Parameters: + mu : float + Theoretically, it should be the inverse of the CFL condition (typically mu=1 for best performance). + In practice, mu=1 works better. + degree : int + Degree of the polynomial used in the time-stepping scheme. + + Returns: + numpy.ndarray + Array of SSPRK coefficients. + """ + + alpha = [0] * self.deg + alpha[0] = 1.0 # Initial coefficient + + # recurrence relation to compute the HORK coefficients following the formula in Gottlieb and Gottlieb (2002) + for i in range(1, self.deg): + alpha[i] = 1 / (mu * (i + 1)) * alpha[i - 1] + alpha[1:i] = [1 / (mu * j) * alpha[j - 1] for j in range(1, i)] + alpha[0] = 1 - sum(alpha[1:i + 1]) + + return alpha + + def source_inclusion(self, current_state, stage_values, e_p, **integration_params): + """ + Include source terms in the time integration step. + + This method applies source term contributions to the right-hand side + of the differential equations during time integration, accounting for + time derivatives of the source function and expansion coefficients. + + Parameters + ---------- + current_state : list + Current state variables (u). + stage_values : list + Current stage values (k). + e_p : list + Expansion coefficients for stability control. + **integration_params : dict + Integration parameters containing 't', 'dt', 'mu', 'src_index', + 'src_deriv', 'n_eq'. + + Returns + ------- + tuple + (modified_rhs, updated_e_p) - Updated right-hand side + equations and modified expansion coefficients. + """ + # Extract integration parameters + mu = integration_params['mu'] + src_index = integration_params['src_index'] + src_deriv = integration_params['src_deriv'] + n_eq = integration_params['n_eq'] + + # Build base right-hand side by substituting current stage values + src_lhs = [uxreplace(self.rhs[i], {current_state[m]: stage_values[m] for m in range(n_eq)}) + for i in range(n_eq)] + + # Apply source term contributions if sources exist + if self.src is not None: + p = len(src_deriv) + + # Add source contributions for each derivative order + for i in range(p): + if e_p[i] != 0: + for j, idx in enumerate(src_index): + # Add weighted source derivative contribution + source_contribution = (self.src[j][0] * src_deriv[i][j].subs({self.t: self.t * self.dt}) * e_p[i]) + src_lhs[idx] += source_contribution + + # Update expansion coefficients for next stage + e_p = [e_p[i] + mu*self.dt*e_p[i + 1] for i in range(p - 1)] + [e_p[-1]] + + return src_lhs, e_p + + def _evaluate(self, **kwargs): + """ + Generate the stage-wise equations for a Runge-Kutta time integration method. + + This method takes a single equation of the form `Eq(u.forward, rhs)` and + expands it into a sequence of intermediate stage evaluations and a final + update equation according to the Runge-Kutta coefficients `a`, `b`, and `c`. + + Returns + ------- + list of Eq + A list of SymPy Eq objects representing: + - `s` stage equations of the form `k_i = rhs evaluated at intermediate state` + - 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` + """ + + sregistry = kwargs.get('sregistry') + # Create a temporary Array for each variable to save the time stages + # k = [Array(name=f'{sregistry.make_name(prefix='k')}', dimensions=u[i].grid.dimensions, grid=u[i].grid, dtype=u[i].dtype) for i in range(n_eq)] + k = [TimeFunction(name=f'{sregistry.make_name(prefix='k')}', grid=self.lhs[i].grid, + space_order=2, time_order=1, dtype=self.lhs[i].dtype) for i in range(self.n_eq)] + k_old = [TimeFunction(name=f'{sregistry.make_name(prefix='k')}', grid=self.lhs[i].grid, + space_order=2, time_order=1, dtype=self.lhs[i].dtype) for i in range(self.n_eq)] + + # Compute SSPRK coefficients + mu = 1 + alpha = self.ssprk_alpha(mu=mu) + + # Initialize symbolic differentiation for source terms + field_map = {val: i for i, val in enumerate(self.lhs)} + if self.src is not None: + src_index = [field_map[src[2]] for src in self.src] + src_deriv = self.source_derivatives(src_index, **kwargs) + else: + src_index = None + src_deriv = None + print('src_index:', src_index) + print('src_deriv:', src_deriv) + + # Expansion coefficients for stability control + e_p = [0] * self.deg + eta = 1 + e_p[-1] = 1 / eta + + stage_eqs = [Eq(ki, ui) for ki, ui in zip(k, self.lhs)] + stage_eqs.extend([Eq(lhs_i.forward, lhs_i*alpha[0]) for lhs_i in self.lhs]) + + # Prepare integration parameters for source inclusion + integration_params = {'mu': mu, 'src_index': src_index, + 'src_deriv': src_deriv, 'n_eq': self.n_eq} + + # Build each stage + for i in range(1, self.deg - 1): + print('e_p:', e_p) + stage_eqs.extend([Eq(k_old_j, k_j) for k_old_j, k_j in zip(k_old, k)]) + src_lhs, e_p = self.source_inclusion(self.lhs, k_old, e_p, **integration_params) + stage_eqs.extend([Eq(k_j, k_old_j+mu*self.dt*src_lhs_j) for k_j, k_old_j, src_lhs_j in zip(k, k_old, src_lhs)]) + stage_eqs.extend([Eq(lhs_j.forward, lhs_j.forward+k_j*alpha[i]) for lhs_j, k_j in zip(self.lhs, k)]) + print('e_p:', e_p) + # Final Runge-Kutta updates + stage_eqs.extend([Eq(k_old_j, k_j) for k_old_j, k_j in zip(k_old, k)]) + src_lhs, e_p = self.source_inclusion(self.lhs, k_old, e_p, **integration_params) + stage_eqs.extend([Eq(k_j, k_old_j+mu*self.dt*src_lhs_j) for k_j, k_old_j, src_lhs_j in zip(k, k_old, src_lhs)]) + print('e_p:', e_p) + stage_eqs.extend([Eq(k_old_j, k_j) for k_old_j, k_j in zip(k_old, k)]) + src_lhs, _ = self.source_inclusion(self.lhs, k_old, e_p, **integration_params) + stage_eqs.extend([Eq(k_j, k_old_j+mu*self.dt*src_lhs_j) for k_j, k_old_j, src_lhs_j in zip(k, k_old, src_lhs)]) + + # Compute final approximation + stage_eqs.extend([Eq(lhs_j.forward, lhs_j.forward+k_j*alpha[self.deg-1]) for lhs_j, k_j in zip(self.lhs, k)]) + + for i in stage_eqs: + print(i) + + return stage_eqs + +method_registry = MappingProxyType(method_registry) \ No newline at end of file diff --git a/tests/test_multistage.py b/tests/test_multistage.py new file mode 100644 index 0000000000..b47c83d012 --- /dev/null +++ b/tests/test_multistage.py @@ -0,0 +1,503 @@ +import pytest +import numpy as np +import sympy as sym +import tempfile +import pickle +import os + +from devito import (Grid, Function, TimeFunction, + Derivative, Operator, solve, Eq, configuration) +from devito.types.multistage import multistage_method, MultiStage +from devito.ir.support import SymbolRegistry +from devito.ir.equations import lower_multistage + +configuration['log-level'] = 'DEBUG' + + +def grid_parameters(extent=(10, 10), shape=(3, 3)): + grid = Grid(origin=(0, 0), extent=extent, shape=shape, dtype=np.float64) + x, y = grid.dimensions + dt = grid.stepping_dim.spacing + t = grid.time_dim + dx = extent[0] / (shape[0] - 1) + return grid, x, y, dt, t, dx + + +def time_parameters(tn, dx, scale=1, t0=0): + t0, tn = 0.0, tn + dt0 = scale / dx**2 + nt = int((tn - t0) / dt0) + dt0 = tn / nt + return tn, dt0, nt + + +class Test_API: + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_pickles(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u = [TimeFunction(name=name, grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system (2D acoustic) + system_eqs_rhs = [u[1] + src_spatial * src_temporal, + Derivative(u[0], (x, 2), fd_order=2) + + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] + + # Class of the time integration scheme + method = multistage_method(u, system_eqs_rhs, time_int) + + with tempfile.NamedTemporaryFile(delete=False) as tmpfile: + pickle.dump(method, tmpfile) + filename = tmpfile.name + + with open(filename, 'rb') as file: + method_saved = pickle.load(file) + os.remove(filename) + + assert str(method) == str( + method_saved), "Mismatch in PDE after pickling" + + op_orig = Operator(method) + op_saved = Operator(method_saved) + + assert str(op_orig) == str(op_saved) + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_solve(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u = [TimeFunction(name=name, grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system (2D acoustic) + system_eqs_rhs = [u[1] + src_spatial * src_temporal, + Derivative(u[0], (x, 2), fd_order=2) + + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] + + # Time integration scheme + pdes = [solve(system_eqs_rhs[i] - u[i], u[i], method=time_int) + for i in range(2)] + + assert all(isinstance(i, MultiStage) + for i in pdes), "Not all elements are instances of MultiStage" + + +class Test_lowering: + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_object(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u = [TimeFunction(name=name, grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system (2D acoustic) + system_eqs_rhs = [u[1] + src_spatial * src_temporal, + Derivative(u[0], (x, 2), fd_order=2) + + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] + + # Class of the time integration scheme + pdes = multistage_method(u, system_eqs_rhs, time_int) + + assert isinstance( + pdes, MultiStage), "Not all elements are instances of MultiStage" + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_lower_multistage(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters(extent=(1, 1), shape=(3, 3)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u = [TimeFunction(name=name, grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system (2D acoustic) + system_eqs_rhs = [u[1] + src_spatial * src_temporal, + Derivative(u[0], (x, 2), fd_order=2) + + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] + + # Class of the time integration scheme + pdes = multistage_method(u, system_eqs_rhs, time_int) + + # Test the lowering process + sregistry = SymbolRegistry() + + # Lower the multistage method - this should not raise an exception + lowered_eqs = lower_multistage(pdes, sregistry=sregistry) + + # Validate the lowered equations + assert lowered_eqs is not None, "Lowering returned None" + assert len(lowered_eqs) > 0, "Lowering returned empty list" + + +class Test_RK: + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_single_equation_integration(self, time_int): + """ + Test single equation time integration with MultiStage methods. + + This test verifies that time integration works correctly for the simplest case: + a single PDE with a single unknown function. This represents the most basic + MultiStage usage scenario (e.g., heat equation, simple wave equation). + """ + + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1, 1), shape=(200, 200)) + + # Define single unknown function + u_multi_stage = TimeFunction(name='u_multi_stage', grid=grid, space_order=2, + time_order=1, dtype=np.float64) + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # Single PDE: du/dt = ∇²u + source (diffusion/wave equation) + eq_rhs = (Derivative(u_multi_stage, (x, 2), fd_order=2) + + Derivative(u_multi_stage, (y, 2), fd_order=2) + + src_spatial * src_temporal) + + # Store initial data for comparison + initial_data = u_multi_stage.data.copy() + + # Time integration scheme - single equation MultiStage object + pde = multistage_method(u_multi_stage, eq_rhs, time_int) + + # Run the operator + op = Operator([pde], subs=grid.spacing_map) # Operator expects a list + op(dt=0.01, time=1) + + # Verify that computation actually occurred (data changed) + assert not np.array_equal( + u_multi_stage.data, initial_data), "Data should have changed" + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_decoupled_equations(self, time_int): + """ + Test decoupled time integration where each equation gets its own MultiStage object. + + This test verifies that time integration works when creating separate MultiStage + objects for each equation, as opposed to coupled integration where all equations + are handled by a single MultiStage object. + """ + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1, 1), shape=(200, 200)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u_multi_stage', 'v_multi_stage'] + u_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, time_order=1, dtype=np.float64) + for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system - each equation independent for decoupled integration + system_eqs_rhs = [u_multi_stage[1] + src_spatial * src_temporal, + Derivative(u_multi_stage[0], (x, 2), fd_order=2) + + Derivative(u_multi_stage[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] + + # Store initial data for comparison + initial_data = [u.data.copy() for u in u_multi_stage] + + # Time integration scheme - create separate MultiStage objects (decoupled) + pdes = [multistage_method(u_multi_stage[i], system_eqs_rhs[i], time_int) + for i in range(len(fun_labels))] + + # Run the operator + op = Operator(pdes, subs=grid.spacing_map) + op(dt=0.01, time=1) + + # Verify that computation actually occurred (data changed) + for i, u in enumerate(u_multi_stage): + assert not np.array_equal( + u.data, initial_data[i]), f"Data should have changed for variable {i}" + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_coupled_op_computing(self, time_int): + """ + Test coupled time integration where all equations are handled by a single MultiStage object. + + This test verifies that time integration works correctly when multiple coupled equations + are integrated together within a single MultiStage object, allowing for proper coupling + between the equations during the time stepping process. + """ + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1, 1), shape=(200, 200)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u_multi_stage', 'v_multi_stage'] + u_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, time_order=1, + dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[1, 1] = 1 + src_temporal = (1 - 2 * (t * dt - 1)**2) + + # PDE system - coupled acoustic wave equations + system_eqs_rhs = [u_multi_stage[1], # velocity equation: du/dt = v + Derivative(u_multi_stage[0], (x, 2), fd_order=2) + + Derivative(u_multi_stage[0], (y, 2), fd_order=2) + + src_spatial * src_temporal] # displacement equation: dv/dt = ∇²u + source + + # Store initial data for comparison + initial_data = [u.data.copy() for u in u_multi_stage] + + # Time integration scheme - single coupled MultiStage object + pdes = multistage_method(u_multi_stage, system_eqs_rhs, time_int) + + # Run the operator + op = Operator(pdes, subs=grid.spacing_map) + op(dt=0.01, time=1) + + # Verify that computation actually occurred (data changed) + for i, u in enumerate(u_multi_stage): + assert not np.array_equal( + u.data, initial_data[i]), f"Data should have changed for variable {i}" + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_low_order_convergence_ODE(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters(extent=(10, 10), shape=(3, 3)) + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[:] = 1 + src_temporal = 2 * t * dt + + # Time axis + tn, dt0, nt = time_parameters(3.0, dx, scale=1e-2) + + # Time integrator solution + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u_multi_stage = [TimeFunction(name=name + '_multi_stage', grid=grid, space_order=2, time_order=1, + dtype=np.float64) for name in fun_labels] + + # PDE (2D acoustic) + eq_rhs = [ + (-1.5 * u_multi_stage[0] + 0.5 * u_multi_stage[1]) * src_spatial * src_temporal, + (-1.5 * u_multi_stage[1] + 0.5 * u_multi_stage[0]) * src_spatial * src_temporal] + u_multi_stage[0].data[0, :] = 1 + + # Time integration scheme + pdes = multistage_method(u_multi_stage, eq_rhs, time_int) + op = Operator(pdes, subs=grid.spacing_map) + op(dt=dt0, time=nt) + + # exact solution + d = np.array([-1, -2]) + a = np.array([[1, 1], [1, -1]]) + exact_sol = np.dot( + np.dot(a, np.diag(np.exp(d * tn**2))), np.linalg.inv(a)) + assert np.max(np.abs(exact_sol[0, 0] - u_multi_stage[0].data[0, :]) + ) < 10 ** -5, "the method is not converging to the solution" + + @pytest.mark.parametrize('time_int', ['RK44', 'RK32', 'RK97']) + def test_low_order_convergence(self, time_int): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1000, 1000), shape=(201, 201)) + + # Medium velocity model + vel = Function(name=f"vel_{time_int}", + grid=grid, space_order=2, dtype=np.float64) + vel.data[:] = 1.0 + vel.data[150:, :] = 1.3 + + # Source definition + src_spatial = Function( + name=f"src_spat_{time_int}", grid=grid, space_order=2, dtype=np.float64) + src_spatial.data[100, 100] = 1 / dx**2 + f0 = 0.01 + src_temporal = (1 - 2 * (np.pi * f0 * (t * dt - 1 / f0))**2) * \ + sym.exp(-(np.pi * f0 * (t * dt - 1 / f0))**2) + + # Time axis + tn, dt0, nt = time_parameters(500.0, dx, scale=1e-1 * np.max(vel.data)) + + # Time integrator solution + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u', 'v'] + u_multi_stage = [ + TimeFunction(name=f"{name}_multi_stage_{time_int}", grid=grid, space_order=2, time_order=1, + dtype=np.float64) for name in fun_labels] + + # PDE (2D acoustic) + eq_rhs = [u_multi_stage[1], (Derivative(u_multi_stage[0], (x, 2), fd_order=2) + + Derivative(u_multi_stage[0], (y, 2), fd_order=2) + + src_spatial * src_temporal) * vel**2] + + # Time integration scheme + pdes = multistage_method(u_multi_stage, eq_rhs, time_int) + op = Operator(pdes, subs=grid.spacing_map) + op(dt=dt0, time=nt) + + # Devito's default solution + u = [TimeFunction(name=f"{name}_{time_int}", grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # PDE (2D acoustic) + eq_rhs = [u[1], (Derivative(u[0], (x, 2), fd_order=2) + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal) * vel**2] + + # Time integration scheme + pdes = [Eq(u[i].forward, solve(Eq(u[i].dt - eq_rhs[i]), u[i].forward)) + for i in range(len(fun_labels))] + op = Operator(pdes, subs=grid.spacing_map) + op(dt=dt0, time=nt) + + assert (np.linalg.norm(u[0].data[0, :] - u_multi_stage[0].data[0, :]) / np.linalg.norm( + u[0].data[0, :])) < 10**-1, "the method is not converging to the solution" + + +class Test_HORK: + + def test_coupled_op_computing_exp(self, time_int='HORK_EXP'): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1, 1), shape=(201, 201)) + + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u_multi_stage', 'v_multi_stage'] + u_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, time_order=1, + dtype=np.float64) for name in fun_labels] + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=np.float64) + src_spatial.data[100, 100] = 1 + src_temporal = sym.exp(- 100 * (t - 0.01) ** 2) + + # PDE system + system_eqs_rhs = [u_multi_stage[1], + Derivative(u_multi_stage[0], (x, 2), fd_order=2) + + Derivative(u_multi_stage[0], (y, 2), fd_order=2)] + + # Store initial data for comparison + initial_data = [u.data.copy() for u in u_multi_stage] + + src = [[src_spatial, src_temporal, u_multi_stage[0]], + [src_spatial, src_temporal * 10, u_multi_stage[0]], + [src_spatial, src_temporal, u_multi_stage[1]]] + + # Time integration scheme + pdes = multistage_method( + u_multi_stage, system_eqs_rhs, time_int, degree=4, source=src) + op = Operator(pdes, subs=grid.spacing_map) + op(dt=0.001, time=2000) + + # Verify that computation actually occurred (data changed) + for i, u in enumerate(u_multi_stage): + assert not np.array_equal( + u.data, initial_data[i]), f"Data should have changed for variable {i}" + + + @pytest.mark.parametrize('degree', list(range(3, 11))) + def test_HORK_EXP_convergence(self, degree): + # Grid setup + grid, x, y, dt, t, dx = grid_parameters( + extent=(1000, 1000), shape=(201, 201)) + + # Medium velocity model + vel = Function(name="vel", grid=grid, space_order=2, dtype=np.float64) + vel.data[:] = 1.0 + vel.data[150:, :] = 1.3 + + # Source definition + src_spatial = Function(name="src_spat", grid=grid, + space_order=2, dtype=np.float64) + src_spatial.data[100, 100] = 1 / dx**2 + f0 = 0.01 + src_temporal = (1 - 2 * (np.pi * f0 * (t - 1 / f0))**2) * sym.exp(-(np.pi * f0 * (t - 1 / f0))**2) + + # Time axis + tn, dt0, nt = time_parameters(500.0, dx, scale=np.max(vel.data)) + + # Time integrator solution + # Define wavefield unknowns: u (displacement) and v (velocity) + fun_labels = ['u_sol', 'v_sol'] + u_multi_stage = [TimeFunction(name=name + '_multi_stage', grid=grid, space_order=2, time_order=1, + dtype=np.float64) for name in fun_labels] + + # PDE (2D acoustic) + eq_rhs = [u_multi_stage[1], (Derivative(u_multi_stage[0],(x,2), fd_order=2) + Derivative( + u_multi_stage[0], (y,2), fd_order=2)) * vel**2] + + src = [[src_spatial * vel**2, src_temporal, u_multi_stage[1]]] + + # Time integration scheme + pdes = multistage_method( + u_multi_stage, eq_rhs, 'HORK_EXP', source=src, degree=degree) + op = Operator(pdes, subs=grid.spacing_map) + op(dt=dt0, time=nt) + + # Devito's default solution + u = [TimeFunction(name=name, grid=grid, space_order=2, + time_order=1, dtype=np.float64) for name in fun_labels] + + # PDE (2D acoustic) + src_temporal = (1 - 2 * (np.pi * f0 * (t * dt - 1 / f0))**2) * sym.exp(-(np.pi * f0 * (t * dt - 1 / f0))**2) + eq_rhs = [u[1], (Derivative(u[0], (x, 2), fd_order=2) + + Derivative(u[0], (y, 2), fd_order=2) + + src_spatial * src_temporal) * vel**2] + + # Time integration scheme + pdes = [Eq(u[i].forward, solve(Eq(u[i].dt - eq_rhs[i]), u[i].forward)) + for i in range(len(fun_labels))] + op = Operator(pdes, subs=grid.spacing_map) + op(dt=dt0, time=nt) + + assert (np.linalg.norm(u[0].data[0, :] - u_multi_stage[0].data[0, :]) / np.linalg.norm( + u[0].data[0, :])) < 10**-1, "the method is not converging to the solution" \ No newline at end of file diff --git a/tests/test_saving_multistage.pkl b/tests/test_saving_multistage.pkl new file mode 100644 index 0000000000..a88f40675c Binary files /dev/null and b/tests/test_saving_multistage.pkl differ