diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index c7f32a4..8f7a79e 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -22,8 +22,10 @@ jobs: run: python -m pip install -r requirements.txt --user - name: Perform editable installation to generate the schema subpackage run: python -m pip install -e . - - name: Run all tests + - name: Run library tests run: python -m pytest + - name: Run end-to-end tests + run: bash tests/end-to-end/test-reconstruction.sh - name: Install pypa/build run: python -m pip install build --user - name: Build a binary wheel and a source tarball diff --git a/.gitignore b/.gitignore index b02fc1e..b69ebc8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *~ *.h5 +*.ismrmrd *.pyc MANIFEST build/ diff --git a/README.md b/README.md index ec1430c..b9cc0c5 100644 --- a/README.md +++ b/README.md @@ -1 +1 @@ -Python implementation of the ISMRMRD +Python implementation of [ISMRMRD](https://github.com/ismrmrd/ismrmrd) diff --git a/examples/demo.py b/examples/demo.py deleted file mode 100644 index bcc3d69..0000000 --- a/examples/demo.py +++ /dev/null @@ -1,11 +0,0 @@ -import ismrmrd - -acq = ismrmrd.Acquisition() -acq.version = 42 -print(acq.version) - -img = ismrmrd.Image() - -f = ismrmrd.Dataset('./testdata.h5', '/dataset', True) -print( f._file) -# xml = f.readHeader() diff --git a/examples/stream_recon.py b/examples/stream_recon.py new file mode 100644 index 0000000..1ecca8e --- /dev/null +++ b/examples/stream_recon.py @@ -0,0 +1,233 @@ +import sys +import argparse +import numpy as np +from typing import BinaryIO, Iterable, Union + +from ismrmrd import Acquisition, Image, ImageHeader, ProtocolDeserializer, ProtocolSerializer +from ismrmrd.xsd import ismrmrdHeader +from ismrmrd.constants import ACQ_IS_NOISE_MEASUREMENT, IMTYPE_MAGNITUDE +from ismrmrd.serialization import SerializableObject + +from numpy.fft import fftshift, ifftshift, fftn, ifftn + + +def kspace_to_image(k: np.ndarray, dim=None, img_shape=None) -> np.ndarray: + """ Computes the Fourier transform from k-space to image space + along a given or all dimensions + + :param k: k-space data + :param dim: vector of dimensions to transform + :param img_shape: desired shape of output image + :returns: data in image space (along transformed dimensions) + """ + if not dim: + dim = range(k.ndim) + img = fftshift(ifftn(ifftshift(k, axes=dim), s=img_shape, axes=dim), axes=dim) + img *= np.sqrt(np.prod(np.take(img.shape, dim))) + return img + + +def image_to_kspace(img: np.ndarray, dim=None, k_shape=None) -> np.ndarray: + """ Computes the Fourier transform from image space to k-space space + along a given or all dimensions + + :param img: image space data + :param dim: vector of dimensions to transform + :param k_shape: desired shape of output k-space data + :returns: data in k-space (along transformed dimensions) + """ + if not dim: + dim = range(img.ndim) + k = fftshift(fftn(ifftshift(img, axes=dim), s=k_shape, axes=dim), axes=dim) + k /= np.sqrt(np.prod(np.take(img.shape, dim))) + return k + + +def acquisition_reader(input: Iterable[SerializableObject]) -> Iterable[Acquisition]: + for item in input: + if not isinstance(item, Acquisition): + # Skip non-acquisition items + continue + if item.flags & ACQ_IS_NOISE_MEASUREMENT: + # Currently ignoring noise scans + continue + yield item + +def stream_item_sink(input: Iterable[Union[Acquisition, Image]]) -> Iterable[SerializableObject]: + for item in input: + if isinstance(item, Acquisition): + yield item + elif isinstance(item, Image) and item.data.dtype == np.float32: + yield item + else: + raise ValueError("Unknown item type") + +def remove_oversampling(head: ismrmrdHeader, input: Iterable[Acquisition]) -> Iterable[Acquisition]: + enc = head.encoding[0] + + if enc.encodedSpace and enc.encodedSpace.matrixSize and enc.reconSpace and enc.reconSpace.matrixSize: + eNx = enc.encodedSpace.matrixSize.x + rNx = enc.reconSpace.matrixSize.x + else: + raise Exception('Encoding information missing from header') + + for acq in input: + if eNx != rNx and acq.number_of_samples == eNx: + xline = kspace_to_image(acq.data, [1]) + x0 = (eNx - rNx) // 2 + x1 = x0 + rNx + xline = xline[:, x0:x1] + head = acq.getHead() + head.center_sample = rNx // 2 + data = image_to_kspace(xline, [1]) + acq = Acquisition(head, data) + yield acq + +def accumulate_fft(head: ismrmrdHeader, input: Iterable[Acquisition]) -> Iterable[Image]: + enc = head.encoding[0] + + # Matrix size + if enc.encodedSpace and enc.reconSpace and enc.encodedSpace.matrixSize and enc.reconSpace.matrixSize: + eNx = enc.encodedSpace.matrixSize.x + eNy = enc.encodedSpace.matrixSize.y + eNz = enc.encodedSpace.matrixSize.z + rNx = enc.reconSpace.matrixSize.x + rNy = enc.reconSpace.matrixSize.y + rNz = enc.reconSpace.matrixSize.z + else: + raise Exception('Required encoding information not found in header') + + # Field of view + if enc.reconSpace and enc.reconSpace.fieldOfView_mm: + rFOVx = enc.reconSpace.fieldOfView_mm.x + rFOVy = enc.reconSpace.fieldOfView_mm.y + rFOVz = enc.reconSpace.fieldOfView_mm.z if enc.reconSpace.fieldOfView_mm.z else 1 + else: + raise Exception('Required field of view information not found in header') + + # Number of Slices, Reps, Contrasts, etc. + ncoils = 1 + if head.acquisitionSystemInformation and head.acquisitionSystemInformation.receiverChannels: + ncoils = head.acquisitionSystemInformation.receiverChannels + + nslices = 1 + if enc.encodingLimits and enc.encodingLimits.slice != None: + nslices = enc.encodingLimits.slice.maximum + 1 + + ncontrasts = 1 + if enc.encodingLimits and enc.encodingLimits.contrast != None: + ncontrasts = enc.encodingLimits.contrast.maximum + 1 + + ky_offset = 0 + if enc.encodingLimits and enc.encodingLimits.kspace_encoding_step_1 != None: + ky_offset = int((eNy+1)/2) - enc.encodingLimits.kspace_encoding_step_1.center + + current_rep = -1 + reference_acquisition = None + buffer = None + image_index = 0 + + def produce_image(buffer: np.ndarray, ref_acq: Acquisition) -> Iterable[Image]: + nonlocal image_index + + if buffer.shape[-3] > 1: + img = kspace_to_image(buffer, dim=[-1, -2, -3]) + else: + img = kspace_to_image(buffer, dim=[-1, -2]) + + for contrast in range(img.shape[0]): + for islice in range(img.shape[1]): + slice = img[contrast, islice] + combined = np.squeeze(np.sqrt(np.abs(np.sum(slice * np.conj(slice), axis=0)).astype('float32'))) + + xoffset = (combined.shape[-1] + 1) // 2 - (rNx+1) // 2 + yoffset = (combined.shape[-2] + 1) // 2 - (rNy+1) // 2 + if len(combined.shape) == 3: + zoffset = (combined.shape[-3] + 1) // 2 - (rNz+1) // 2 + combined = combined[zoffset:(zoffset+rNz), yoffset:(yoffset+rNy), xoffset:(xoffset+rNx)] + combined = np.reshape(combined, (1, combined.shape[-3], combined.shape[-2], combined.shape[-1])) + elif len(combined.shape) == 2: + combined = combined[yoffset:(yoffset+rNy), xoffset:(xoffset+rNx)] + combined = np.reshape(combined, (1, 1, combined.shape[-2], combined.shape[-1])) + else: + raise Exception('Array img_combined should have 2 or 3 dimensions') + + imghdr = ImageHeader(image_type=IMTYPE_MAGNITUDE) + imghdr.version = 1 + imghdr.measurement_uid = ref_acq.measurement_uid + imghdr.field_of_view[0] = rFOVx + imghdr.field_of_view[1] = rFOVy + imghdr.field_of_view[2] = rFOVz/rNz + imghdr.position = ref_acq.position + imghdr.read_dir = ref_acq.read_dir + imghdr.phase_dir = ref_acq.phase_dir + imghdr.slice_dir = ref_acq.slice_dir + imghdr.patient_table_position = ref_acq.patient_table_position + imghdr.average = ref_acq.idx.average + imghdr.slice = ref_acq.idx.slice + imghdr.contrast = contrast + imghdr.phase = ref_acq.idx.phase + imghdr.repetition = ref_acq.idx.repetition + imghdr.set = ref_acq.idx.set + imghdr.acquisition_time_stamp = ref_acq.acquisition_time_stamp + imghdr.physiology_time_stamp = ref_acq.physiology_time_stamp + imghdr.image_index = image_index + image_index += 1 + + mrd_image = Image(head=imghdr, data=combined) + yield mrd_image + + for acq in input: + if acq.idx.repetition != current_rep: + # If we have a current buffer pass it on + if buffer is not None and reference_acquisition is not None: + yield from produce_image(buffer, reference_acquisition) + + # Reset buffer + if acq.data.shape[-1] == eNx: + readout_length = eNx + else: + readout_length = rNx # Readout oversampling has been removed upstream + + buffer = np.zeros((ncontrasts, nslices, ncoils, eNz, eNy, readout_length), dtype=np.complex64) + current_rep = acq.idx.repetition + reference_acquisition = acq + + # Stuff into the buffer + if buffer is not None: + contrast = acq.idx.contrast if acq.idx.contrast is not None else 0 + slice = acq.idx.slice if acq.idx.slice is not None else 0 + k1 = acq.idx.kspace_encode_step_1 if acq.idx.kspace_encode_step_1 is not None else 0 + k2 = acq.idx.kspace_encode_step_2 if acq.idx.kspace_encode_step_2 is not None else 0 + buffer[contrast, slice, :, k2, k1 + ky_offset, :] = acq.data + + if buffer is not None and reference_acquisition is not None: + yield from produce_image(buffer, reference_acquisition) + buffer = None + reference_acquisition = None + +def reconstruct_ismrmrd_stream(input: BinaryIO, output: BinaryIO): + with ProtocolDeserializer(input) as reader, ProtocolSerializer(output) as writer: + stream = reader.deserialize() + head = next(stream, None) + if head is None: + raise Exception("Could not read ISMRMRD header") + if not isinstance(head, ismrmrdHeader): + raise Exception("First item in stream is not an ISMRMRD header") + writer.serialize(head) + for item in stream_item_sink( + accumulate_fft(head, + remove_oversampling(head, + acquisition_reader(stream)))): + writer.serialize(item) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Reconstructs an ISMRMRD stream") + parser.add_argument('-i', '--input', type=str, required=False, help="Input stream, defaults to stdin") + parser.add_argument('-o', '--output', type=str, required=False, help="Output stream, defaults to stdout") + args = parser.parse_args() + + input = args.input if args.input is not None else sys.stdin.buffer + output = args.output if args.output is not None else sys.stdout.buffer + + reconstruct_ismrmrd_stream(input, output) diff --git a/ismrmrd/__init__.py b/ismrmrd/__init__.py index 430cc7d..435a692 100644 --- a/ismrmrd/__init__.py +++ b/ismrmrd/__init__.py @@ -1,10 +1,11 @@ from .constants import * -from .acquisition import * -from .image import * +from .acquisition import AcquisitionHeader, Acquisition, EncodingCounters +from .image import ImageHeader, Image from .hdf5 import Dataset from .meta import Meta -from .waveform import * +from .waveform import WaveformHeader, Waveform from .file import File +from .serialization import ProtocolSerializer, ProtocolDeserializer from . import xsd diff --git a/ismrmrd/file.py b/ismrmrd/file.py index ce4b276..6cef0e1 100644 --- a/ismrmrd/file.py +++ b/ismrmrd/file.py @@ -1,11 +1,11 @@ import h5py -import numpy +import numpy as np -from .hdf5 import * -from .acquisition import * -from .waveform import * -from .image import * -from .xsd import ToXML +from .hdf5 import acquisition_header_dtype, acquisition_dtype, waveform_header_dtype, waveform_dtype, image_header_dtype +from .acquisition import Acquisition +from .waveform import Waveform +from .image import Image +from .xsd import ToXML, CreateFromDocument class DataWrapper: @@ -32,7 +32,7 @@ def __setitem__(self, key, value): except TypeError: iterable = [self.to_numpy(value)] - self.data[key] = numpy.array(iterable, dtype=self.datatype) + self.data[key] = np.array(iterable, dtype=self.datatype) def __repr__(self): return type(self).__name__ + " containing " + self.data.__repr__() @@ -98,7 +98,7 @@ def from_numpy(cls, raw): # of padding and alignment. Nu guarantees are given, so we need create a structured array # with a header to have the contents filled in correctly. We start with an array of # zeroes to avoid garbage in the padding bytes. - header_array = numpy.zeros((1,), dtype=waveform_header_dtype) + header_array = np.zeros((1,), dtype=waveform_header_dtype) header_array[0] = raw['head'] waveform = Waveform(header_array) @@ -149,9 +149,9 @@ def __setitem__(self, key, value): except TypeError: iterable = [self.to_numpy(value)] - self.headers[key] = numpy.stack([header for header, _, __ in iterable]) - self.data[key] = numpy.stack([data for _, data, __ in iterable]) - self.attributes[key] = numpy.stack([attributes for _, __, attributes in iterable]) + self.headers[key] = np.stack([header for header, _, __ in iterable]) + self.data[key] = np.stack([data for _, data, __ in iterable]) + self.attributes[key] = np.stack([attributes for _, __, attributes in iterable]) @classmethod def from_numpy(cls, header, data, attributes): @@ -252,7 +252,7 @@ def __set_acquisitions(self, acquisitions): if self.has_images(): raise TypeError("Cannot add acquisitions when images are present.") - buffer = numpy.array([Acquisitions.to_numpy(a) for a in acquisitions], dtype=acquisition_dtype) + buffer = np.array([Acquisitions.to_numpy(a) for a in acquisitions], dtype=acquisition_dtype) self.__del_acquisitions() self._contents.create_dataset('data',data=buffer,maxshape=(None,),chunks=True) @@ -275,7 +275,7 @@ def __set_waveforms(self, waveforms): raise TypeError("Cannot add waveforms when images are present.") converter = Waveforms(None) - buffer = numpy.array([converter.to_numpy(w) for w in waveforms], dtype=waveform_dtype) + buffer = np.array([converter.to_numpy(w) for w in waveforms], dtype=waveform_dtype) self.__del_waveforms() self._contents.create_dataset('waveforms', data=buffer, maxshape=(None,),chunks=True) @@ -303,9 +303,9 @@ def __set_images(self, images): images = list(images) - data = numpy.stack([image.data for image in images]) - headers = numpy.stack([np.frombuffer(image.getHead(), dtype=image_header_dtype) for image in images]) - attributes = numpy.stack([image.attribute_string for image in images]) + data = np.stack([image.data for image in images]) + headers = np.stack([np.frombuffer(image.getHead(), dtype=image_header_dtype) for image in images]) + attributes = np.stack([image.attribute_string for image in images]) self.__del_images() self._contents.create_dataset('data', data=data) @@ -323,7 +323,7 @@ def __del_images(self): def __get_header(self): if not self.has_header(): return None - return ismrmrd.xsd.CreateFromDocument(self._contents['xml'][0]) + return CreateFromDocument(self._contents['xml'][0]) def __set_header(self, header): self.__del_header() diff --git a/ismrmrd/image.py b/ismrmrd/image.py index cad7560..71973ff 100644 --- a/ismrmrd/image.py +++ b/ismrmrd/image.py @@ -293,14 +293,10 @@ def meta(self, val): raise RuntimeError("meta must be of type Meta or dict") @property - def matrix_size(self, warn=True): - if warn: - warnings.warn( - "This function currently returns a result that is inconsistent (transposed) " + - "compared to the matrix_size in the ImageHeader and from .getHead().matrix_size. " + - "This function will be made consistent in a future version and this message " + - "will be removed." - ) + def matrix_size(self): + """This function currently returns a result that is inconsistent (transposed) + compared to the matrix_size in the ImageHeader and from .getHead().matrix_size. + This function will be made consistent in a future version and this message will be removed.""" return self.__data.shape[1:4] @property diff --git a/ismrmrd/serialization.py b/ismrmrd/serialization.py new file mode 100644 index 0000000..93df092 --- /dev/null +++ b/ismrmrd/serialization.py @@ -0,0 +1,201 @@ +""" +Implements ProtocolSerializer and ProtocolDeserializer for streaming ISMRMRD objects (Acquisition, Image, Waveform, etc.) +""" +import struct +from typing import Union, BinaryIO, Any, Generator, cast +import numpy as np + +from ismrmrd.acquisition import Acquisition +from ismrmrd.image import Image, get_data_type_from_dtype, get_dtype_from_data_type +from ismrmrd.waveform import Waveform +from ismrmrd.xsd import ismrmrdHeader, CreateFromDocument + +from enum import IntEnum + +# Type alias for serializable objects +SerializableObject = Union[Acquisition, Image, Waveform, ismrmrdHeader, np.ndarray, str] + +class ISMRMRDMessageID(IntEnum): + UNPEEKED = 0 + CONFIG_FILE = 1 + CONFIG_TEXT = 2 + HEADER = 3 + CLOSE = 4 + TEXT = 5 + ACQUISITION = 1008 + IMAGE = 1022 + WAVEFORM = 1026 + NDARRAY = 1030 + +class ProtocolSerializer: + """ + Serializes ISMRMRD objects to a binary stream. + """ + + def __init__(self, stream: Union[BinaryIO, str]) -> None: + if isinstance(stream, str): + self._stream = cast(BinaryIO, open(stream, "wb")) + self._owns_stream = True + else: + self._stream = stream + self._owns_stream = False + + def __enter__(self) -> 'ProtocolSerializer': + return self + + def __exit__(self, exc_type: Union[type[BaseException], None], exc: Union[BaseException, None], traceback: Any) -> None: + try: + self.close() + except Exception as e: + if exc is None: + raise e + + def close(self) -> None: + self._write_message_id(ISMRMRDMessageID.CLOSE) + self._stream.flush() + if self._owns_stream: + self._stream.close() + + def _write_message_id(self, msgid: ISMRMRDMessageID) -> None: + self._stream.write(struct.pack(' None: + """ + Serializes an ISMRMRD object and writes to the configured stream. + """ + if isinstance(obj, Acquisition): + self._write_message_id(ISMRMRDMessageID.ACQUISITION) + obj.serialize_into(self._stream.write) + elif isinstance(obj, Image): + self._write_message_id(ISMRMRDMessageID.IMAGE) + obj.serialize_into(self._stream.write) + elif isinstance(obj, Waveform): + self._write_message_id(ISMRMRDMessageID.WAVEFORM) + obj.serialize_into(self._stream.write) + elif isinstance(obj, ismrmrdHeader): + self._write_message_id(ISMRMRDMessageID.HEADER) + self._serialize_ismrmrd_header(obj) + elif isinstance(obj, np.ndarray): + self._write_message_id(ISMRMRDMessageID.NDARRAY) + self._serialize_ndarray(obj) + elif isinstance(obj, str): + self._write_message_id(ISMRMRDMessageID.TEXT) + self._serialize_text(obj) + else: + raise TypeError(f"Unsupported type: {type(obj)}") + + def _serialize_ismrmrd_header(self, header: ismrmrdHeader) -> None: + xml_bytes = header.toXML().encode('utf-8') + self._stream.write(struct.pack(' None: + ver = 0 + dtype = get_data_type_from_dtype(arr.dtype) + ndim = arr.ndim + dims = arr.shape + self._stream.write(struct.pack(' None: + text_bytes = text.encode('utf-8') + self._stream.write(struct.pack(' None: + if isinstance(stream, str): + self._stream = cast(BinaryIO, open(stream, "rb")) + self._owns_stream = True + else: + self._stream = stream + self._owns_stream = False + + def __enter__(self) -> 'ProtocolDeserializer': + return self + + def __exit__(self, exc_type: Union[type[BaseException], None], exc: Union[BaseException, None], traceback: Any) -> None: + try: + self.close() + except Exception as e: + if exc is None: + raise e + + def close(self) -> None: + if self._owns_stream: + self._stream.close() + + def deserialize(self) -> Generator[SerializableObject, None, None]: + """ + Reads from the stream, yielding each ISMRMRD object as a generator. + """ + while True: + msg_id_bytes = self._stream.read(2) + if not msg_id_bytes or len(msg_id_bytes) < 2: + raise EOFError("End of stream or incomplete message ID") + msg_id = struct.unpack(' ismrmrdHeader: + length_bytes = self._stream.read(4) + if len(length_bytes) < 4: + raise EOFError("Incomplete header length") + length = struct.unpack(' np.ndarray: + header_fmt = ' str: + length_bytes = self._stream.read(4) + if len(length_bytes) < 4: + raise EOFError("Incomplete text length") + length = struct.unpack('/dev/null + +# Generate reference reconstructed image(s) +docker run -i --rm -v "${WORKDIR}":/tmp ${ISMRMRD_IMAGE} ismrmrd_hdf5_to_stream -i /tmp/testdata.h5 --use-stdout \ + | docker run -i --rm ${ISMRMRD_IMAGE} ismrmrd_stream_recon_cartesian_2d --use-stdin --use-stdout --output-magnitude \ + > reference-image-stream.ismrmrd + +# Pipe phantom dataset through the Python ProtocolSerializer to test compatibility +docker run --rm -i -v "${WORKDIR}":/tmp ${ISMRMRD_IMAGE} ismrmrd_hdf5_to_stream -i /tmp/testdata.h5 --use-stdout \ + | python "${PROJECT_DIR}"/utilities/ismrmrd_copy_stream.py > phantom.ismrmrd + +# Reconstruct the images using Python +python examples/stream_recon.py --input phantom.ismrmrd --output reconstructed.ismrmrd + +# Compare the images +python "${SCRIPT_DIR}"/validate_results.py --reference reference-image-stream.ismrmrd --testdata reconstructed.ismrmrd + +# Reconstruct the images using Python (stdin/stdout) +python examples/stream_recon.py < phantom.ismrmrd > reconstructed.ismrmrd + +# Compare the images again +python "${SCRIPT_DIR}"/validate_results.py --reference reference-image-stream.ismrmrd --testdata reconstructed.ismrmrd + +echo "Success!" \ No newline at end of file diff --git a/tests/end-to-end/validate_results.py b/tests/end-to-end/validate_results.py new file mode 100644 index 0000000..427bdde --- /dev/null +++ b/tests/end-to-end/validate_results.py @@ -0,0 +1,56 @@ +import os +import argparse +import numpy as np +from pathlib import Path + +import ismrmrd + +def test_basic_recon(reference_file, testdata_file): + reference_reader = ismrmrd.ProtocolDeserializer(reference_file) + reference_images = list(reference_reader.deserialize()) + + testdata_reader = ismrmrd.ProtocolDeserializer(testdata_file) + test_stream = list(testdata_reader.deserialize()) + # The reconstruction may or may not include the ismrmrdHeader + if len(test_stream) > 1: + assert len(test_stream) == 2 + test_header = test_stream[0] + assert isinstance(test_header, ismrmrd.xsd.ismrmrdHeader), "First object in test stream should be an IsmrmrdHeader" + test_images = test_stream[1:] + else: + test_images = test_stream + + assert len(test_images) == len(reference_images), "Expected matching number of reference and test images" + for ref_img, test_img in zip(reference_images, test_images): + ref_norm = normalize(ref_img.data) + test_norm = normalize(test_img.data) + assert np.allclose(ref_norm, test_norm), "Normalized reference and test images do not match closely enough" + + # assert test_img.version == ref_img.version + assert test_img.measurement_uid == ref_img.measurement_uid + assert test_img.position[:] == ref_img.position[:] + assert test_img.read_dir[:] == ref_img.read_dir[:] + assert test_img.phase_dir[:] == ref_img.phase_dir[:] + assert test_img.slice_dir[:] == ref_img.slice_dir[:] + assert test_img.patient_table_position[:] == ref_img.patient_table_position[:] + assert test_img.acquisition_time_stamp == ref_img.acquisition_time_stamp + assert test_img.physiology_time_stamp[:] == ref_img.physiology_time_stamp[:] + assert test_img.image_type == ref_img.image_type + assert test_img.image_index == ref_img.image_index + assert test_img.image_series_index == ref_img.image_series_index + assert test_img.user_int[:] == ref_img.user_int[:] + assert test_img.user_float[:] == ref_img.user_float[:] + + +# Z-score normalization function +def normalize(data): + return (data - np.mean(data)) / np.std(data) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Validate results of a streaming ismrmrd reconstruction') + parser.add_argument('--reference', type=str, help='Reference image stream', required=True) + parser.add_argument('--testdata', type=str, help='Test image stream', required=True) + args = parser.parse_args() + with open(args.testdata, 'rb') as testdata_file: + with open(args.reference, 'rb') as reference_file: + test_basic_recon(reference_file, testdata_file) diff --git a/tests/test_common.py b/tests/test_common.py index aa95764..e44acba 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,6 +5,55 @@ import random +import ismrmrd.xsd + + +example_header = """ + + + 32130323 + + + + + 64 + 64 + 1 + + + 300 + 300 + 40 + + + + + 64 + 64 + 1 + + + 300 + 300 + 40 + + + radial + + + + +""" + + +def create_example_ismrmrd_header(): + """Create a sample ISMRMRD header using the example from test_file.py.""" + return ismrmrd.xsd.CreateFromDocument(example_header) + +def create_random_ndarray(): + """Create a sample multi-dimensional numpy array.""" + # Create a 4D array to test multiple dimensions: (batch, channels, height, width) + return np.random.rand(2, 8, 64, 64).astype(np.float32) def random_tuple(size, random_fn): return tuple([random_fn() for _ in range(0, size)]) diff --git a/tests/test_file.py b/tests/test_file.py index c92f45f..5fa799a 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -272,46 +272,9 @@ def test_file_can_rewrite_data_and_images(): imageset.images = random_images(2) imageset.images = random_images(3) -example_header = """ - - - 32130323 - - - - - 64 - 64 - 1 - - - 300 - 300 - 40 - - - - - 64 - 64 - 1 - - - 300 - 300 - 40 - - - radial - - - - -""" - def test_file_can_read_and_write_headers(): filename = os.path.join(temp_dir, "file.h5") - header = ismrmrd.xsd.CreateFromDocument(example_header) + header = create_example_ismrmrd_header() with ismrmrd.File(filename) as file: dataset = file['dataset'] dataset.header = header diff --git a/tests/test_serialization.py b/tests/test_serialization.py new file mode 100644 index 0000000..b7fe0b0 --- /dev/null +++ b/tests/test_serialization.py @@ -0,0 +1,269 @@ +from ctypes import c_float, c_int32, c_uint, c_uint16, c_uint32, c_uint64 +import io +import numpy as np +import test_common as common +from ismrmrd.acquisition import Acquisition, EncodingCounters, AcquisitionHeader +from ismrmrd.image import Image, ImageHeader +from ismrmrd.meta import Meta +from ismrmrd.waveform import Waveform, WaveformHeader +from ismrmrd.serialization import ProtocolSerializer, ProtocolDeserializer +from ismrmrd.xsd import ismrmrdHeader, CreateFromDocument +from test_file import example_header + +def make_acquisition(): + # Build header + header = AcquisitionHeader() + header.version = 1 + header.flags = 2 + header.measurement_uid = 123 + header.scan_counter = 1 + header.acquisition_time_stamp = 456 + header.physiology_time_stamp[:] = [1, 2, 3] + header.number_of_samples = 4 + header.available_channels = 2 + header.active_channels = 2 + header.channel_mask[:] = [1] * 16 + header.discard_pre = 0 + header.discard_post = 0 + header.center_sample = 2 + header.encoding_space_ref = 0 + header.trajectory_dimensions = 2 + header.sample_time_us = 1.0 + header.position[:] = [0.0, 0.0, 0.0] + header.read_dir[:] = [1.0, 0.0, 0.0] + header.phase_dir[:] = [0.0, 1.0, 0.0] + header.slice_dir[:] = [0.0, 0.0, 1.0] + header.patient_table_position[:] = [0.0, 0.0, 0.0] + header.idx = EncodingCounters() + header.idx.kspace_encode_step_1 = 0 + header.idx.kspace_encode_step_2 = 0 + header.idx.average = 0 + header.idx.slice = 0 + header.idx.contrast = 0 + header.idx.phase = 0 + header.idx.repetition = 0 + header.idx.set = 0 + header.idx.segment = 0 + header.idx.user[:] = [0] * 8 + header.user_int[:] = [0] * 8 + header.user_float[:] = [0.0] * 8 + # Build data arrays + trajectory = np.ones((header.number_of_samples, header.trajectory_dimensions), dtype='float32') + data = np.ones((header.active_channels, header.number_of_samples), dtype='complex64') + # Construct Acquisition + acq = Acquisition(header, data, trajectory) + return acq + +def make_image(): + # Build header + header = ImageHeader() + header.version = 1 + header.data_type = 5 # FLOAT + header.flags = 2 + header.measurement_uid = 123 + header.matrix_size[:] = [2, 2, 2] + header.field_of_view[:] = [1.0, 1.0, 1.0] + header.channels = 1 + header.position[:] = [0.0, 0.0, 0.0] + header.read_dir[:] = [1.0, 0.0, 0.0] + header.phase_dir[:] = [0.0, 1.0, 0.0] + header.slice_dir[:] = [0.0, 0.0, 1.0] + header.patient_table_position[:] = [0.0, 0.0, 0.0] + header.average = 0 + header.slice = 0 + header.contrast = 0 + header.phase = 0 + header.repetition = 0 + header.set = 0 + header.acquisition_time_stamp = 456 + header.physiology_time_stamp[:] = [1, 2, 3] + header.image_type = 1 + header.image_index = 1 + header.image_series_index = 1 + header.user_int[:] = [0] * 8 + header.user_float[:] = [0.0] * 8 + + data = np.ones((1, 2, 2, 2), dtype='float32') + meta = Meta({"foo": "bar"}) + img = Image(head=header, data=data, meta=meta) + return img + +def make_waveform(): + head = WaveformHeader() + head.version = 1 + head.flags = 2 + head.measurement_uid = 123 + head.scan_counter = 1 + head.time_stamp = 456 + head.number_of_samples = 4 + + head.channels = 2 + head.sample_time_us = 1.0 + head.waveform_id = 42 + + data = np.ones((2, 4), dtype=np.uint32) + wf = Waveform(head, data) + return wf + +def test_acquisition_serialization(): + acq = common.create_random_acquisition() + stream = io.BytesIO() + serializer = ProtocolSerializer(stream) + serializer.serialize(acq) + serializer.close() + stream.seek(0) + deserializer = ProtocolDeserializer(stream) + objects = list(deserializer.deserialize()) + assert len(objects) == 1 + acq2 = objects[0] + assert np.allclose(acq.data, acq2.data) + assert np.allclose(acq.traj, acq2.traj) + assert acq.number_of_samples == acq2.number_of_samples + assert acq.active_channels == acq2.active_channels + assert acq.flags == acq2.flags + assert acq.idx == acq2.idx + +def test_image_serialization(): + img = common.create_random_image() + stream = io.BytesIO() + serializer = ProtocolSerializer(stream) + serializer.serialize(img) + serializer.close() + stream.seek(0) + deserializer = ProtocolDeserializer(stream) + objects = list(deserializer.deserialize()) + assert len(objects) == 1 + img2 = objects[0] + assert np.allclose(img.data, img2.data) + assert img.matrix_size == img2.matrix_size + assert img.channels == img2.channels + assert img.meta == img2.meta + assert img.attribute_string == img2.attribute_string + +def test_waveform_serialization(): + wf = common.create_random_waveform() + stream = io.BytesIO() + serializer = ProtocolSerializer(stream) + serializer.serialize(wf) + serializer.close() + stream.seek(0) + deserializer = ProtocolDeserializer(stream) + objects = list(deserializer.deserialize()) + assert len(objects) == 1 + wf2 = objects[0] + assert np.allclose(wf.data, wf2.data) + assert wf.number_of_samples == wf2.number_of_samples + assert wf.channels == wf2.channels + assert wf.waveform_id == wf2.waveform_id + + +def test_ismrmrd_header_serialization(): + header = common.create_example_ismrmrd_header() + stream = io.BytesIO() + serializer = ProtocolSerializer(stream) + serializer.serialize(header) + serializer.close() + stream.seek(0) + deserializer = ProtocolDeserializer(stream) + objects = list(deserializer.deserialize()) + assert len(objects) == 1 + header2 = objects[0] + # Compare by converting both to XML strings + assert header.toXML() == header2.toXML() + + +def test_ndarray_serialization(): + arr = common.create_random_ndarray() + stream = io.BytesIO() + serializer = ProtocolSerializer(stream) + serializer.serialize(arr) + serializer.close() + stream.seek(0) + deserializer = ProtocolDeserializer(stream) + objects = list(deserializer.deserialize()) + assert len(objects) == 1 + arr2 = objects[0] + assert np.array_equal(arr, arr2) + assert arr.dtype == arr2.dtype + assert arr.shape == arr2.shape + + +def test_text_serialization(): + text = "Hello, ISMRMRD! This is a test text message with special characters: àáâãäåæçèéêë 123456789 !@#$%^&*()" + stream = io.BytesIO() + serializer = ProtocolSerializer(stream) + serializer.serialize(text) + serializer.close() + stream.seek(0) + deserializer = ProtocolDeserializer(stream) + objects = list(deserializer.deserialize()) + assert len(objects) == 1 + text2 = objects[0] + assert text == text2 + assert isinstance(text2, str) + + +def test_interleaved_serialization(): + """Test serialization and deserialization of multiple interleaved objects.""" + # Create test objects + acq = common.create_random_acquisition() + img = common.create_random_image() + wf = common.create_random_waveform() + header = common.create_example_ismrmrd_header() + arr = common.create_random_ndarray() + text = "Test interleaved text message 🚀" + + # Serialize all objects in a specific order + stream = io.BytesIO() + serializer = ProtocolSerializer(stream) + serializer.serialize(acq) + serializer.serialize(text) + serializer.serialize(img) + serializer.serialize(header) + serializer.serialize(arr) + serializer.serialize(wf) + serializer.close() + + # Deserialize and verify order and content + stream.seek(0) + deserializer = ProtocolDeserializer(stream) + objects = list(deserializer.deserialize()) + assert len(objects) == 6 + + # First object: Acquisition + obj1 = objects[0] + assert isinstance(obj1, Acquisition) + assert np.allclose(acq.data, obj1.data) + assert np.allclose(acq.traj, obj1.traj) + assert acq.flags == obj1.flags + + # Second object: Text + obj2 = objects[1] + assert isinstance(obj2, str) + assert text == obj2 + + # Third object: Image + obj3 = objects[2] + assert isinstance(obj3, Image) + assert np.allclose(img.data, obj3.data) + assert img.matrix_size == obj3.matrix_size + assert img.channels == obj3.channels + + # Fourth object: Header + obj4 = objects[3] + assert isinstance(obj4, type(header)) # ismrmrdHeader type + assert header.toXML() == obj4.toXML() + + # Fifth object: NDArray + obj5 = objects[4] + assert isinstance(obj5, np.ndarray) + assert np.array_equal(arr, obj5) + assert arr.dtype == obj5.dtype + assert arr.shape == obj5.shape + + # Sixth object: Waveform + obj6 = objects[5] + assert isinstance(obj6, Waveform) + assert np.allclose(wf.data, obj6.data) + assert wf.number_of_samples == obj6.number_of_samples + assert wf.channels == obj6.channels \ No newline at end of file diff --git a/utilities/ismrmrd_copy_stream.py b/utilities/ismrmrd_copy_stream.py new file mode 100644 index 0000000..2752ee6 --- /dev/null +++ b/utilities/ismrmrd_copy_stream.py @@ -0,0 +1,12 @@ +import sys +from ismrmrd.serialization import ProtocolSerializer, ProtocolDeserializer + + +def main(): + reader = ProtocolDeserializer(sys.stdin.buffer) + with ProtocolSerializer(sys.stdout.buffer) as writer: + for item in reader.deserialize(): + writer.serialize(item) + +if __name__ == '__main__': + main() \ No newline at end of file