@@ -29,7 +29,7 @@ class OrthoSlicer3D(object):
2929 -------
3030 >>> import numpy as np
3131 >>> a = np.sin(np.linspace(0,np.pi,20))
32- >>> b = np.sin(np.linspace(0,np.pi*5,20))
32+ >>> b = np.sin(np.linspace(0,np.pi*5,20))asa
3333 >>> data = np.outer(a,b)[..., np.newaxis]*a
3434 >>> OrthoSlicer3D(data).show() # doctest: +SKIP
3535 """
@@ -44,11 +44,12 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
4444 dimensions.
4545 affine : array-like | None
4646 Affine transform for the data. This is used to determine
47- how the data should be sliced for plotting into the X, Y,
48- and Z view axes. If None, identity is assumed. The aspect
49- ratio of the data are inferred from the affine transform.
47+ how the data should be sliced for plotting into the saggital,
48+ coronal, and axial view axes. If None, identity is assumed.
49+ The aspect ratio of the data are inferred from the affine
50+ transform.
5051 axes : tuple of mpl.Axes | None, optional
51- 3 or 4 axes instances for the X, Y, Z slices plus volumes,
52+ 3 or 4 axes instances for the 3 slices plus volumes,
5253 or None (default).
5354 cmap : str | instance of cmap, optional
5455 String or cmap instance specifying colormap.
@@ -63,39 +64,43 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
6364 affine = np .array (affine , float ) if affine is not None else np .eye (4 )
6465 if affine .ndim != 2 or affine .shape != (4 , 4 ):
6566 raise ValueError ('affine must be a 4x4 matrix' )
67+ # determine our orientation
6668 self ._affine = affine .copy ()
67- self ._codes = axcodes2ornt (aff2axcodes (self ._affine )) # XXX USE FOR ORDERING
68- print (self ._codes )
69+ codes = axcodes2ornt (aff2axcodes (self ._affine ))
70+ order = np .argsort ([c [0 ] for c in codes ])
71+ flips = np .array ([c [1 ] for c in codes ])[order ]
72+ self ._order = dict (x = int (order [0 ]), y = int (order [1 ]), z = int (order [2 ]))
73+ self ._flips = dict (x = flips [0 ], y = flips [1 ], z = flips [2 ])
6974 self ._scalers = np .abs (self ._affine ).max (axis = 0 )[:3 ]
7075 self ._inv_affine = np .linalg .inv (affine )
76+ # current volume info
7177 self ._volume_dims = data .shape [3 :]
7278 self ._current_vol_data = data [:, :, :, 0 ] if data .ndim > 3 else data
7379 self ._data = data
74- pcnt_range = (0 , 100 ) if pcnt_range is None else pcnt_range
7580 vmin , vmax = np .percentile (data , pcnt_range )
7681 del data
7782
7883 if axes is None : # make the axes
7984 # ^ +---------+ ^ +---------+
8085 # | | | | | |
86+ # | Sag | | Cor |
87+ # S | 1 | S | 2 |
8188 # | | | |
82- # z | 2 | z | 3 |
8389 # | | | |
84- # | | | | | |
85- # v +---------+ v +---------+
86- # <-- x --> <-- y -->
87- # ^ +---------+ ^ +---------+
88- # | | | | | |
90+ # +---------+ +---------+
91+ # A --> <-- R
92+ # ^ +---------+ +---------+
93+ # | | | | |
94+ # | Axial | | |
95+ # A | 3 | | 4 |
8996 # | | | |
90- # y | 1 | A | 4 |
9197 # | | | |
92- # | | | | | |
93- # v +---------+ v +---------+
94- # <-- x --> <-- t -->
98+ # +---------+ +---------+
99+ # <-- R <-- t -->
95100
96101 fig , axes = plt .subplots (2 , 2 )
97102 fig .set_size_inches (figsize , forward = True )
98- self ._axes = dict (x = axes [0 , 1 ], y = axes [0 , 0 ], z = axes [1 , 0 ],
103+ self ._axes = dict (x = axes [0 , 0 ], y = axes [0 , 1 ], z = axes [1 , 0 ],
99104 v = axes [1 , 1 ])
100105 plt .tight_layout (pad = 0.1 )
101106 if self .n_volumes <= 1 :
@@ -111,14 +116,15 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
111116
112117 # Start midway through each axis, idx is current slice number
113118 self ._ims , self ._sizes , self ._idx = dict (), dict (), dict ()
119+ self ._vol = 0
114120 colors = dict ()
115- for k , size in zip ('xyz' , self ._data .shape [:3 ]):
121+ for k in 'xyz' :
122+ size = self ._data .shape [self ._order [k ]]
116123 self ._idx [k ] = size // 2
117124 self ._ims [k ] = self ._axes [k ].imshow (self ._get_slice_data (k ), ** kw )
118125 self ._sizes [k ] = size
119126 colors [k ] = (0 , 1 , 0 )
120- self ._idx ['v' ] = 0
121- labels = dict (z = 'ILSR' , y = 'ALPR' , x = 'AIPS' )
127+ labels = dict (x = 'SAIP' , y = 'SLIR' , z = 'ALPR' )
122128
123129 # set up axis crosshairs
124130 self ._crosshairs = dict ()
@@ -231,7 +237,7 @@ def set_position(self, x=None, y=None, z=None, v=None):
231237 'image' )
232238 self ._set_vol_idx (v )
233239 draw = True
234- for key , val in zip ('zyx ' , (z , y , x )):
240+ for key , val in zip ('xyz ' , (x , y , z )):
235241 if val is not None :
236242 self ._set_viewer_slice (key , val )
237243 draw = True
@@ -241,9 +247,11 @@ def set_position(self, x=None, y=None, z=None, v=None):
241247
242248 def _get_voxel_levels (self ):
243249 """Get levels of the current voxel as a function of volume"""
244- y = self ._data [self ._idx ['x' ],
245- self ._idx ['y' ],
246- self ._idx ['z' ], :].ravel ()
250+ # XXX THIS IS WRONG
251+ #y = self._data[self._idx['x'],
252+ # self._idx['y'],
253+ # self._idx['z'], :].ravel()
254+ y = self ._data [0 , 0 , 0 , :].ravel ()
247255 y = np .concatenate ((y , [y [- 1 ]]))
248256 return y
249257
@@ -255,20 +263,34 @@ def _update_voxel_levels(self):
255263 def _set_vol_idx (self , idx ):
256264 """Change which volume is shown"""
257265 max_ = np .prod (self ._volume_dims )
258- self ._idx [ 'v' ] = max (min (int (round (idx )), max_ - 1 ), 0 )
266+ self ._vol = max (min (int (round (idx )), max_ - 1 ), 0 )
259267 # Must reset what is shown
260- self ._current_vol_data = self ._data [:, :, :, self ._idx [ 'v' ] ]
268+ self ._current_vol_data = self ._data [:, :, :, self ._vol ]
261269 for key in 'xyz' :
262270 self ._ims [key ].set_data (self ._get_slice_data (key ))
263- self ._volume_ax_objs ['patch' ].set_x (self ._idx [ 'v' ] - 0.5 )
271+ self ._volume_ax_objs ['patch' ].set_x (self ._vol - 0.5 )
264272
265273 def _get_slice_data (self , key ):
266274 """Helper to get the current slice image"""
267- ii = dict (x = 0 , y = 1 , z = 2 )[key ]
268- return np .take (self ._current_vol_data , self ._idx [key ], axis = ii ).T
275+ assert key in ['x' , 'y' , 'z' ]
276+ data = np .take (self ._current_vol_data , self ._idx [key ],
277+ axis = self ._order [key ])
278+ # saggital: get to S/A
279+ # coronal: get to S/L
280+ # axial: get to A/L
281+ xaxes = dict (x = 'y' , y = 'x' , z = 'x' )
282+ yaxes = dict (x = 'z' , y = 'z' , z = 'y' )
283+ if self ._order [xaxes [key ]] < self ._order [yaxes [key ]]:
284+ data = data .T
285+ if self ._flips [xaxes [key ]]:
286+ data = data [:, ::- 1 ]
287+ if self ._flips [yaxes [key ]]:
288+ data = data [::- 1 ]
289+ return data
269290
270291 def _set_viewer_slice (self , key , idx ):
271292 """Helper to set a viewer slice number"""
293+ assert key in ['x' , 'y' , 'z' ]
272294 self ._idx [key ] = max (min (int (round (idx )), self ._sizes [key ] - 1 ), 0 )
273295 self ._ims [key ].set_data (self ._get_slice_data (key ))
274296 for fun in self ._cross_setters [key ]:
@@ -293,7 +315,9 @@ def _on_scroll(self, event):
293315 if self .n_volumes <= 1 :
294316 return
295317 key = 'v' # shift: change volume in any axis
296- idx = self ._idx [key ] + (delta if event .button == 'up' else - delta )
318+ assert key in ['x' , 'y' , 'z' , 'v' ]
319+ idx = self ._idx [key ] if key != 'v' else self ._vol
320+ idx += delta if event .button == 'up' else - delta
297321 if key == 'v' :
298322 self ._set_vol_idx (idx )
299323 else :
0 commit comments