Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
1d830b8
implementation of multi-stage time integrators
fernanvr May 5, 2025
7f087b3
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Jun 13, 2025
214d882
Return of first PR comments
fernanvr Jun 13, 2025
d6c4d4a
Return of first PR comments
fernanvr Jun 13, 2025
78f8a0b
2nd PR revision
fernanvr Jun 23, 2025
1c9d517
2nd PR revision
fernanvr Jun 23, 2025
11db48b
2rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
83dfb04
3rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
d47a106
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
1f93a45
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
eea3a52
5th PR revision, one suggestion from EdC and improving tests
fernanvr Jun 26, 2025
11d1429
including two more Runge-Kutta methods and improving tests: checking …
fernanvr Jul 1, 2025
4637ac2
changes to consider coupled Multistage equations
fernanvr Jul 16, 2025
ac1da7e
Improvements of the HORK_EXP
fernanvr Aug 15, 2025
e9b3533
Merge branch 'main' into multi-stage-time-integrator
fernanvr Aug 15, 2025
dc3dd77
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Oct 8, 2025
1fd4a02
tuples, improved class names, extensive tests
fernanvr Oct 8, 2025
a0c45c1
improving spacing in some tests
fernanvr Oct 8, 2025
93c6e3f
Add MFE time stepping Jupyter notebook
fernanvr Oct 23, 2025
ef8d1ac
Remove MFE_time_size.ipynb notebook
fernanvr Oct 23, 2025
fa5acac
Update multistage implementation and tests
fernanvr Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
637 changes: 637 additions & 0 deletions MFE_time_size.ipynb

Large diffs are not rendered by default.

38 changes: 36 additions & 2 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from devito.tools import (Ordering, as_tuple, flatten, filter_sorted, filter_ordered,
frozendict)
from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension,
ConditionalDimension)
ConditionalDimension, MultiStage)
from devito.types.array import Array
from devito.types.basic import AbstractFunction
from devito.types.dimension import MultiSubDimension, Thickness
from devito.data.allocators import DataReference
from devito.logger import warning

__all__ = ['dimension_sort', 'lower_exprs', 'concretize_subdims']

__all__ = ['dimension_sort', 'lower_multistage', 'lower_exprs', 'concretize_subdims']


def dimension_sort(expr):
Expand Down Expand Up @@ -95,6 +96,39 @@ def handle_indexed(indexed):
return ordering


def lower_multistage(expressions, **kwargs):
"""
Separating the multi-stage time-integrator scheme in stages:
* If the object is MultiStage, it creates the stages of the method.
"""
return _lower_multistage(expressions, **kwargs)


@singledispatch
def _lower_multistage(expr, **kwargs):
"""
Default handler for expressions that are not MultiStage.
Simply return them in a list.
"""
return [expr]


@_lower_multistage.register(MultiStage)
def _(expr, **kwargs):
"""
Specialized handler for MultiStage expressions.
"""
return expr._evaluate(**kwargs)


@_lower_multistage.register(Iterable)
def _(exprs, **kwargs):
"""
Handle iterables of expressions.
"""
return sum([_lower_multistage(expr, **kwargs) for expr in exprs], [])


def lower_exprs(expressions, subs=None, **kwargs):
"""
Lowering an expression consists of the following passes:
Expand Down
9 changes: 7 additions & 2 deletions devito/operations/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from devito.finite_differences.derivative import Derivative
from devito.tools import as_tuple

from devito.types.multistage import resolve_method

__all__ = ['solve', 'linsolve']


Expand Down Expand Up @@ -56,9 +58,12 @@ def solve(eq, target, **kwargs):

# We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions
if len(sols) > 1:
return target.new_from_mat(sols)
sols_temp = target.new_from_mat(sols)
else:
return sols[0]
sols_temp = sols[0]

method = kwargs.get("method", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is method a string here, or is it the class for the method? In the latter case, it would remove the need to have the method_registry mapper. Furthermore, it would allow you to have method.resolve(target, sols_temp) here, which is tidier

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a string. The idea is that the user provides a string to identify which time integrator to apply.

return sols_temp if method is None else resolve_method(method)(target, sols_temp)


def linsolve(expr, target, **kwargs):
Expand Down
5 changes: 3 additions & 2 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
InvalidOperator)
from devito.logger import (debug, info, perf, warning, is_log_enabled_for,
switch_log_level)
from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
from devito.ir.equations import LoweredEq, lower_multistage, lower_exprs, concretize_subdims
from devito.ir.clusters import ClusterGroup, clusterize
from devito.ir.iet import (Callable, CInterface, EntryFunction, DeviceFunction,
FindSymbols, MetaCall, derive_parameters, iet_build)
Expand All @@ -40,7 +40,6 @@
disk_layer)
from devito.types.dimension import Thickness


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please run the linter (flake8) 🙂

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

__all__ = ['Operator']


Expand Down Expand Up @@ -337,6 +336,8 @@ def _lower_exprs(cls, expressions, **kwargs):
* Apply substitution rules;
* Shift indices for domain alignment.
"""
expressions = lower_multistage(expressions, **kwargs)

expand = kwargs['options'].get('expand', True)

# Specialization is performed on unevaluated expressions
Expand Down
2 changes: 2 additions & 0 deletions devito/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@
from .relational import * # noqa
from .sparse import * # noqa
from .tensor import * # noqa

from .multistage import * # noqa
Loading