Skip to content

Commit dc44841

Browse files
committed
abi: move mpi_type_get_envelope etc. into templates
MPI_Type_get_envelop and MPI_Type_get_contents are a big mess as the MPI Forum decided not just to add big variants but also add additional arguments for the big count variants. So this necessitated enhancements to the binding infrastructure to support optional suppressing of bc and non-bc variants of prototype files. Signed-off-by: Howard Pritchard <howardp@lanl.gov>
1 parent 1c63109 commit dc44841

File tree

10 files changed

+399
-56
lines changed

10 files changed

+399
-56
lines changed

ompi/mpi/bindings/bindings.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def main():
6161
# parser = argparse.ArgumentParser(description='C ABI binding generation code')
6262
parser_gen.add_argument('type', choices=('ompi', 'standard'),
6363
help='generate the OMPI ABI functions or the standard ABI functions')
64-
parser_gen.add_argument('--mpit', action='store_true', help='generate MPI T code generation')
64+
parser_gen.add_argument('--mpit', action='store_true', help='generate MPI T code')
65+
parser_gen.add_argument('--suppress_bc', action='store_true', help='do not generate big count variant')
66+
parser_gen.add_argument('--suppress_nbc', action='store_true', help='do not generate int count variant')
6567
parser_gen.add_argument('source_file', help='template file to use for C code generation')
6668
parser_gen.set_defaults(handler=lambda args, out: c.generate_source(args, out))
6769
args = parser.parse_args()

ompi/mpi/bindings/ompi_bindings/c.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -416,17 +416,22 @@ def print_cdefs_for_abi(out, abi_type='ompi'):
416416
out.dump('#undef OMPI_ABI_SRC')
417417
out.dump('#define OMPI_ABI_SRC 1')
418418

419-
def ompi_abi(base_name, template, out):
419+
def ompi_abi(base_name, template, out, suppress_bc=False, suppress_nbc=False):
420420
"""Generate the OMPI ABI functions."""
421421
template.print_header(out)
422-
print_profiling_header(base_name, out)
423-
print_cdefs_for_bigcount(out)
424-
print_cdefs_for_abi(out)
425-
out.dump(template.prototype.signature(base_name, abi_type='ompi'))
426-
template.print_body(func_name=base_name, out=out)
422+
if suppress_nbc == False:
423+
print_profiling_header(base_name, out)
424+
print_cdefs_for_bigcount(out)
425+
print_cdefs_for_abi(out)
426+
out.dump(template.prototype.signature(base_name, abi_type='ompi'))
427+
template.print_body(func_name=base_name, out=out)
427428
# Check if we need to generate the bigcount interface
428-
if util.prototype_has_bigcount(template.prototype):
429-
base_name_c = f'{base_name}_c'
429+
if util.prototype_has_bigcount(template.prototype) and suppress_bc == False:
430+
# there are some special cases where we need to explicitly define the bigcount functions in the template file
431+
if base_name[-2:] == "_c":
432+
base_name_c = f'{base_name}'
433+
else:
434+
base_name_c = f'{base_name}_c'
430435
print_profiling_header(base_name_c, out)
431436
print_cdefs_for_bigcount(out, enable_count=True)
432437
print_cdefs_for_abi(out)
@@ -438,7 +443,7 @@ def ompi_abi(base_name, template, out):
438443
ABI_INTERNAL_CONVERTOR = 'ompi/mpi/c/abi_converters.h'
439444

440445

441-
def standard_abi(base_name, template, out):
446+
def standard_abi(base_name, template, out, suppress_bc=False, suppress_nbc=False):
442447
"""Generate the standard ABI functions."""
443448
template.print_header(out)
444449
out.dump(f'#include "{ABI_INTERNAL_HEADER}"')
@@ -457,14 +462,15 @@ def standard_abi(base_name, template, out):
457462
out.dump(line)
458463

459464
# Static internal function (add a random component to avoid conflicts)
460-
internal_name = f'ompi_abi_{template.prototype.name}'
461-
print_cdefs_for_bigcount(out)
462-
print_cdefs_for_abi(out, abi_type='standard')
463-
internal_sig = template.prototype.signature(internal_name, abi_type='ompi',
464-
enable_count=False)
465-
out.dump(consts.INLINE_ATTRS, internal_sig)
466-
template.print_body(func_name=base_name, out=out)
467-
if util.prototype_has_bigcount(template.prototype):
465+
if suppress_nbc == False:
466+
internal_name = f'ompi_abi_{template.prototype.name}'
467+
print_cdefs_for_bigcount(out)
468+
print_cdefs_for_abi(out, abi_type='standard')
469+
internal_sig = template.prototype.signature(internal_name, abi_type='ompi',
470+
enable_count=False)
471+
out.dump(consts.INLINE_ATTRS, internal_sig)
472+
template.print_body(func_name=base_name, out=out)
473+
if util.prototype_has_bigcount(template.prototype) and suppress_bc == False:
468474
internal_name = f'ompi_abi_{template.prototype.name}_c'
469475
print_cdefs_for_bigcount(out, enable_count=True)
470476
print_cdefs_for_abi(out, abi_type='standard')
@@ -502,10 +508,14 @@ def generate_function(prototype, fn_name, internal_fn, out, enable_count=False):
502508
out.dump(line)
503509
out.dump('}')
504510

505-
internal_name = f'ompi_abi_{template.prototype.name}'
506-
generate_function(template.prototype, base_name, internal_name, out)
507-
if util.prototype_has_bigcount(template.prototype):
508-
base_name_c = f'{base_name}_c'
511+
if suppress_nbc == False:
512+
internal_name = f'ompi_abi_{template.prototype.name}'
513+
generate_function(template.prototype, base_name, internal_name, out)
514+
if util.prototype_has_bigcount(template.prototype) and suppress_bc == False:
515+
if base_name[-2:] == "_c":
516+
base_name_c = f'{base_name}'
517+
else:
518+
base_name_c = f'{base_name}_c'
509519
internal_name = f'ompi_abi_{template.prototype.name}_c'
510520
generate_function(template.prototype, base_name_c, internal_name, out,
511521
enable_count=True)
@@ -529,6 +539,6 @@ def generate_source(args, out):
529539
else:
530540
base_name = util.mpi_fn_name_from_base_fn_name(template.prototype.name)
531541
if args.type == 'ompi':
532-
ompi_abi(base_name, template, out)
542+
ompi_abi(base_name, template, out, args.suppress_bc, args.suppress_nbc)
533543
else:
534-
standard_abi(base_name, template, out)
544+
standard_abi(base_name, template, out, args.suppress_bc, args.suppress_nbc)

ompi/mpi/bindings/ompi_bindings/c_type.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,14 @@ def parameter(self, enable_count=False, **kwargs):
139139
count_type = 'MPI_Count' if enable_count else 'int'
140140
return f'const {count_type} {self.name}[]'
141141

142+
@Type.add_type('COUNT_ARRAY_OUT')
143+
class TypeCountArrayOut(TypeCountArray):
144+
"""Array of counts out (either int or MPI_Count)."""
145+
146+
def parameter(self, enable_count=False, **kwargs):
147+
count_type = 'MPI_Count' if enable_count else 'int'
148+
return f'{count_type} {self.name}[]'
149+
142150
@Type.add_type('AINT_COUNT_ARRAY')
143151
class TypeAintCountArray(Type):
144152
"""Array of counts (either MPI_Aint or MPI_Count)."""
@@ -154,6 +162,14 @@ def parameter(self, enable_count=False, **kwargs):
154162
count_type = 'MPI_Count' if enable_count else 'MPI_Aint'
155163
return f'const {count_type} {self.name}[]'
156164

165+
@Type.add_type('AINT_COUNT_ARRAY_OUT')
166+
class TypeAintCountArrayOut(TypeAintCountArray):
167+
"""Array of counts (either MPI_Aint or MPI_Count)."""
168+
169+
def parameter(self, enable_count=False, **kwargs):
170+
count_type = 'MPI_Count' if enable_count else 'MPI_Aint'
171+
return f'{count_type} {self.name}[]'
172+
157173
@Type.add_type('ELEMENT_COUNT')
158174
class ElementCountType(Type):
159175
"""Special count type for MPI_Get_element_x"""
@@ -226,6 +242,11 @@ def type_text(self, enable_count=False):
226242
def parameter(self, enable_count=False, **kwargs):
227243
return f'const MPI_Aint {self.name}[]'
228244

245+
@Type.add_type('AINT_ARRAY_OUT')
246+
class TypeAintArrayOut(TypeAintArray):
247+
248+
def parameter(self, enable_count=False, **kwargs):
249+
return f'MPI_Aint {self.name}[]'
229250

230251
@Type.add_type('INT_OUT')
231252
class TypeIntOut(Type):
@@ -282,6 +303,15 @@ def type_text(self, enable_count=False):
282303
def parameter(self, enable_count=False, **kwargs):
283304
return f'const int {self.name}[]'
284305

306+
@Type.add_type('INT_ARRAY_OUT')
307+
class TypeIntArrayOut(TypeIntArray):
308+
309+
def type_text(self, enable_count=False):
310+
return 'int *'
311+
312+
def parameter(self, enable_count=False, **kwargs):
313+
return f'int {self.name}[]'
314+
285315
@Type.add_type('INT_AINT_OUT')
286316
class TypeIntAintOut(Type):
287317

@@ -362,6 +392,14 @@ def type_text(self, enable_count=False):
362392
def parameter(self, enable_count=False, **kwargs):
363393
return f'const {self.type_text(enable_count=enable_count)} {self.name}[]'
364394

395+
@Type.add_type('DATATYPE_ARRAY_OUT', abi_type=['ompi'])
396+
class TypeDatatypeArrayOut(Type):
397+
398+
def type_text(self, enable_count=False):
399+
return 'MPI_Datatype'
400+
401+
def parameter(self, enable_count=False, **kwargs):
402+
return f'{self.type_text(enable_count=enable_count)} {self.name}[]'
365403

366404
class StandardABIType(Type):
367405

@@ -406,9 +444,6 @@ def type_text(self, enable_count=False):
406444
def argument(self):
407445
return f'(MPI_Datatype *) {self.name}'
408446

409-
#
410-
# TODO THIS IS NOT COMPLETE
411-
#
412447
@Type.add_type('DATATYPE_ARRAY', abi_type=['standard'])
413448
class TypeDatatypeArrayStandard(StandardABIType):
414449

@@ -444,6 +479,38 @@ def parameter(self, enable_count=False, **kwargs):
444479
def argument(self):
445480
return f'(MPI_Datatype *) {self.tmpname}'
446481

482+
@Type.add_type('DATATYPE_ARRAY_OUT', abi_type=['standard'])
483+
class TypeDatatypeArrayOutStandard(StandardABIType):
484+
485+
@property
486+
def init_code(self):
487+
code = [f'int size_{self.tmpname} = {self.count_param};']
488+
code.append(f'MPI_Datatype *{self.tmpname} = (MPI_Datatype *)malloc({self.count_param} * sizeof(MPI_Datatype));')
489+
return code
490+
491+
@property
492+
def final_code(self):
493+
code = [f'for(int i=0;i<size_{self.tmpname};i++){{']
494+
code.append(f'{self.name}[i] = {ConvertOMPIToStandard.DATATYPE}({self.tmpname}[i]);')
495+
code.append(f'}}')
496+
code.append(f'free({self.tmpname});')
497+
return code
498+
499+
@property
500+
def tmpname(self):
501+
return f'{self.name}_tmp'
502+
503+
def type_text(self, enable_count=False):
504+
return self.mangle_name('MPI_Datatype')
505+
506+
def parameter(self, enable_count=False, **kwargs):
507+
return f'{self.type_text(enable_count=enable_count)} {self.name}[]'
508+
509+
@property
510+
def argument(self):
511+
return f'(MPI_Datatype *) {self.tmpname}'
512+
513+
447514
@Type.add_type('OP', abi_type=['ompi'])
448515
class TypeDatatype(Type):
449516

ompi/mpi/bindings/ompi_bindings/parser.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# $HEADER$
99

1010
import os
11+
import sys
1112

1213
"""Source parsing code."""
1314

@@ -16,7 +17,11 @@ class Parameter:
1617
def __init__(self, text, type_constructor):
1718
"""Parse a parameter."""
1819
# parameter in the form "TYPE NAME" or "TYPE NAME:COUNT_VAR"
19-
type_name, namecount = text.split()
20+
try:
21+
type_name, namecount = text.split()
22+
except Exception as e:
23+
print(f"Error: could not split '{text}' got error {e}")
24+
sys.exit(-1)
2025
if ':' in namecount:
2126
name, count_param = namecount.split(':')
2227
else:

ompi/mpi/c/Makefile.am

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,14 @@ prototype_sources = \
470470
win_wait.c.in \
471471
wtime.c.in
472472

473+
prototype_sources_nbc = \
474+
type_get_contents.c.in_nbc \
475+
type_get_envelope.c.in_nbc
476+
477+
prototype_sources_obc = \
478+
type_get_contents_c.c.in_obc \
479+
type_get_envelope_c.c.in_obc
480+
473481
# See MPI-5 standard Chapter 20 section 4
474482
prototype_sources_not_in_abi = \
475483
comm_c2f.c.in \
@@ -502,6 +510,8 @@ prototype_sources_not_in_abi = \
502510
win_f2c.c.in
503511

504512
EXTRA_DIST = $(prototype_sources) \
513+
$(prototype_sources_nbc) \
514+
$(prototype_sources_obc) \
505515
$(prototype_sources_not_in_abi) \
506516
abi_converters.h \
507517
abi_get_info.c.in
@@ -529,8 +539,9 @@ nobase_include_HEADERS = abi.h standard_abi/mpi.h
529539
#
530540
#
531541
interface_profile_sources = $(prototype_sources:.c.in=_ompi_generated.c) \
532-
$(prototype_sources_not_in_abi:.c.in=_ompi_generated.c)
533-
542+
$(prototype_sources_not_in_abi:.c.in=_ompi_generated.c) \
543+
$(prototype_sources_nbc:.c.in_nbc=_ompi_generated.c) \
544+
$(prototype_sources_obc:.c.in_obc=_ompi_generated.c)
534545

535546
# Conditionally install the header files
536547
if WANT_INSTALL_HEADERS
@@ -543,10 +554,6 @@ endif
543554
#
544555
extra_interface_profile_sources = \
545556
pcontrol.c \
546-
type_get_contents.c \
547-
type_get_contents_c.c \
548-
type_get_envelope.c \
549-
type_get_envelope_c.c \
550557
wtick.c
551558

552559
# The following functions were removed from the MPI standard, but are
@@ -584,6 +591,51 @@ if OMPI_GENERATE_BINDINGS
584591
ompi \
585592
$<
586593

594+
# Deal with oddballs wrt big count
595+
type_get_contents_ompi_generated.c: type_get_contents.c.in_nbc
596+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
597+
--builddir $(abs_top_builddir) \
598+
--srcdir $(abs_top_srcdir) \
599+
--output $@ \
600+
c \
601+
source \
602+
ompi \
603+
--suppress_bc \
604+
$<
605+
606+
type_get_envelope_ompi_generated.c: type_get_envelope.c.in_nbc
607+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
608+
--builddir $(abs_top_builddir) \
609+
--srcdir $(abs_top_srcdir) \
610+
--output $@ \
611+
c \
612+
source \
613+
ompi \
614+
--suppress_bc \
615+
$<
616+
617+
type_get_contents_c_ompi_generated.c: type_get_contents_c.c.in_obc
618+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
619+
--builddir $(abs_top_builddir) \
620+
--srcdir $(abs_top_srcdir) \
621+
--output $@ \
622+
c \
623+
source \
624+
ompi \
625+
--suppress_nbc \
626+
$<
627+
628+
type_get_envelope_c_ompi_generated.c: type_get_envelope_c.c.in_obc
629+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
630+
--builddir $(abs_top_builddir) \
631+
--srcdir $(abs_top_srcdir) \
632+
--output $@ \
633+
c \
634+
source \
635+
ompi \
636+
--suppress_nbc \
637+
$<
638+
587639
# Non-mangled version
588640
standard_abi/mpi.h: $(top_srcdir)/docs/mpi-standard-apis.json $(top_srcdir)/ompi/mpi/bindings/c_header.py
589641
mkdir -p standard_abi
@@ -611,6 +663,51 @@ abi.h: $(top_srcdir)/docs/mpi-standard-apis.json $(top_srcdir)/ompi/mpi/bindings
611663
source \
612664
standard \
613665
$<
666+
667+
# Deal with oddballs wrt big count
668+
type_get_contents_abi_generated.c: type_get_contents.c.in_nbc
669+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
670+
--builddir $(abs_top_builddir) \
671+
--srcdir $(abs_top_srcdir) \
672+
--output $@ \
673+
c \
674+
source \
675+
standard \
676+
--suppress_bc \
677+
$<
678+
679+
type_get_envelope_abi_generated.c: type_get_envelope.c.in_nbc
680+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
681+
--builddir $(abs_top_builddir) \
682+
--srcdir $(abs_top_srcdir) \
683+
--output $@ \
684+
c \
685+
source \
686+
standard \
687+
--suppress_bc \
688+
$<
689+
690+
type_get_contents_c_abi_generated.c: type_get_contents_c.c.in_obc
691+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
692+
--builddir $(abs_top_builddir) \
693+
--srcdir $(abs_top_srcdir) \
694+
--output $@ \
695+
c \
696+
source \
697+
standard \
698+
--suppress_nbc \
699+
$<
700+
701+
type_get_envelope_c_abi_generated.c: type_get_envelope_c.c.in_obc
702+
$(OMPI_V_GEN) $(PYTHON) $(top_srcdir)/ompi/mpi/bindings/bindings.py \
703+
--builddir $(abs_top_builddir) \
704+
--srcdir $(abs_top_srcdir) \
705+
--output $@ \
706+
c \
707+
source \
708+
standard \
709+
--suppress_nbc \
710+
$<
614711
endif
615712

616713
MAINTAINERCLEANFILES = *_generated.c abi_get_info.c $(nobase_include_HEADERS)

0 commit comments

Comments
 (0)