diff --git a/grudge/array_context.py b/grudge/array_context.py index c5672178d..3348c25af 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -384,9 +384,11 @@ def to_output_template(keys, _): class MPIPytatoArrayContextBase(MPIBasedArrayContext): def __init__( - self, mpi_communicator, queue, *, mpi_base_tag, allocator=None, - compile_trace_callback: Optional[Callable[[Any, str, Any], None]] - = None) -> None: + self, mpi_communicator, queue, *, + mpi_base_tag, allocator=None, + compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None, + use_axis_tag_inference_fallback: bool = False, + use_einsum_inference_fallback: bool = False) -> None: """ :arg compile_trace_callback: A function of three arguments *(what, stage, ir)*, where *what* identifies the object @@ -401,7 +403,9 @@ def __init__( "to reduce device allocations)") super().__init__(queue, allocator, - compile_trace_callback=compile_trace_callback) + compile_trace_callback=compile_trace_callback, + use_axis_tag_inference_fallback=use_axis_tag_inference_fallback, + use_einsum_inference_fallback=use_einsum_inference_fallback) self.mpi_communicator = mpi_communicator self.mpi_base_tag = mpi_base_tag @@ -416,7 +420,9 @@ def clone(self): # pylint: disable=no-member return type(self)(self.mpi_communicator, self.queue, mpi_base_tag=self.mpi_base_tag, - allocator=self.allocator) + allocator=self.allocator, + use_axis_tag_inference_fallback=self.use_axis_tag_inference_fallback, + use_einsum_inference_fallback=self.use_einsum_inference_fallback) # }}}