@@ -70,10 +70,9 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
7070 # determine our orientation
7171 self ._affine = affine .copy ()
7272 codes = axcodes2ornt (aff2axcodes (self ._affine ))
73- order = np .argsort ([c [0 ] for c in codes ])
74- flips = np .array ([c [1 ] < 0 for c in codes ])[order ]
75- self ._order = dict (x = int (order [0 ]), y = int (order [1 ]), z = int (order [2 ]))
76- self ._flips = dict (x = flips [0 ], y = flips [1 ], z = flips [2 ])
73+ self ._order = np .argsort ([c [0 ] for c in codes ])
74+ self ._flips = np .array ([c [1 ] < 0 for c in codes ])[self ._order ]
75+ self ._flips = list (self ._flips ) + [False ] # add volume dim
7776 self ._scalers = np .abs (self ._affine ).max (axis = 0 )[:3 ]
7877 self ._inv_affine = np .linalg .inv (affine )
7978 # current volume info
@@ -87,56 +86,54 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
8786 # ^ +---------+ ^ +---------+
8887 # | | | | | |
8988 # | Sag | | Cor |
90- # S | 1 | S | 2 |
89+ # S | 0 | S | 1 |
9190 # | | | |
9291 # | | | |
9392 # +---------+ +---------+
9493 # A --> <-- R
9594 # ^ +---------+ +---------+
9695 # | | | | |
9796 # | Axial | | Vol |
98- # A | 3 | | 4 |
97+ # A | 2 | | 3 |
9998 # | | | |
10099 # | | | |
101100 # +---------+ +---------+
102101 # <-- R <-- t -->
103102
104103 fig , axes = plt .subplots (2 , 2 )
105104 fig .set_size_inches (figsize , forward = True )
106- self ._axes = dict (x = axes [0 , 0 ], y = axes [0 , 1 ], z = axes [1 , 0 ],
107- v = axes [1 , 1 ])
105+ self ._axes = [axes [0 , 0 ], axes [0 , 1 ], axes [1 , 0 ], axes [1 , 1 ]]
108106 plt .tight_layout (pad = 0.1 )
109107 if self .n_volumes <= 1 :
110- fig .delaxes (self ._axes ['v' ])
111- del self ._axes [ 'v' ]
108+ fig .delaxes (self ._axes [3 ])
109+ self ._axes . pop ( - 1 )
112110 else :
113- self ._axes = dict ( z = axes [0 ], y = axes [1 ], x = axes [2 ])
111+ self ._axes = [ axes [0 ], axes [1 ], axes [2 ]]
114112 if len (axes ) > 3 :
115- self ._axes [ 'v' ] = axes [3 ]
113+ self ._axes . append ( axes [3 ])
116114
117115 # Start midway through each axis, idx is current slice number
118- self ._ims , self ._sizes , self . _data_idx = dict (), dict (), dict ()
116+ self ._ims , self ._data_idx = list (), list ()
119117
120118 # set up axis crosshairs
121- self ._crosshairs = dict ()
122- r = [self ._scalers [self ._order ['z' ]] / self ._scalers [self ._order ['y' ]],
123- self ._scalers [self ._order ['z' ]] / self ._scalers [self ._order ['x' ]],
124- self ._scalers [self ._order ['y' ]] / self ._scalers [self ._order ['x' ]]]
125- for k in 'xyz' :
126- self ._sizes [k ] = self ._data .shape [self ._order [k ]]
127- for k , xax , yax , ratio , label in zip ('xyz' , 'yxx' , 'zzy' , r ,
128- ('SAIP' , 'SLIR' , 'ALPR' )):
129- ax = self ._axes [k ]
119+ self ._crosshairs = [None ] * 3
120+ r = [self ._scalers [self ._order [2 ]] / self ._scalers [self ._order [1 ]],
121+ self ._scalers [self ._order [2 ]] / self ._scalers [self ._order [0 ]],
122+ self ._scalers [self ._order [1 ]] / self ._scalers [self ._order [0 ]]]
123+ self ._sizes = [self ._data .shape [o ] for o in self ._order ]
124+ for ii , xax , yax , ratio , label in zip ([0 , 1 , 2 ], [1 , 0 , 0 ], [2 , 2 , 1 ],
125+ r , ('SAIP' , 'SLIR' , 'ALPR' )):
126+ ax = self ._axes [ii ]
130127 d = np .zeros ((self ._sizes [yax ], self ._sizes [xax ]))
131- self . _ims [ k ] = self ._axes [k ].imshow (d , vmin = vmin , vmax = vmax ,
132- aspect = 1 , cmap = cmap ,
133- interpolation = 'nearest' ,
134- origin = 'lower' )
128+ im = self ._axes [ii ].imshow (d , vmin = vmin , vmax = vmax , aspect = 1 ,
129+ cmap = cmap , interpolation = 'nearest' ,
130+ origin = 'lower' )
131+ self . _ims . append ( im )
135132 vert = ax .plot ([0 ] * 2 , [- 0.5 , self ._sizes [yax ] - 0.5 ],
136133 color = (0 , 1 , 0 ), linestyle = '-' )[0 ]
137134 horiz = ax .plot ([- 0.5 , self ._sizes [xax ] - 0.5 ], [0 ] * 2 ,
138135 color = (0 , 1 , 0 ), linestyle = '-' )[0 ]
139- self ._crosshairs [k ] = dict (vert = vert , horiz = horiz )
136+ self ._crosshairs [ii ] = dict (vert = vert , horiz = horiz )
140137 # add text labels (top, right, bottom, left)
141138 lims = [0 , self ._sizes [xax ], 0 , self ._sizes [yax ]]
142139 bump = 0.01
@@ -156,12 +153,12 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
156153 ax .set_frame_on (False )
157154 ax .axes .get_yaxis ().set_visible (False )
158155 ax .axes .get_xaxis ().set_visible (False )
159- self ._data_idx [ k ] = 0
160- self ._data_idx [ 'v' ] = - 1
156+ self ._data_idx . append ( 0 )
157+ self ._data_idx . append ( - 1 ) # volume
161158
162159 # Set up volumes axis
163- if self .n_volumes > 1 and 'v' in self ._axes :
164- ax = self ._axes ['v' ]
160+ if self .n_volumes > 1 and len ( self ._axes ) > 3 :
161+ ax = self ._axes [3 ]
165162 ax .set_axis_bgcolor ('k' )
166163 ax .set_title ('Volumes' )
167164 y = np .zeros (self .n_volumes + 1 )
@@ -179,7 +176,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
179176 ax .set_ylim (yl )
180177 self ._volume_ax_objs = dict (step = step , patch = patch )
181178
182- self ._figs = set ([a .figure for a in self ._axes . values () ])
179+ self ._figs = set ([a .figure for a in self ._axes ])
183180 for fig in self ._figs :
184181 fig .canvas .mpl_connect ('scroll_event' , self ._on_scroll )
185182 fig .canvas .mpl_connect ('motion_notify_event' , self ._on_mouse )
@@ -287,14 +284,14 @@ def set_volume_idx(self, v):
287284
288285 def _set_volume_index (self , v , update_slices = True ):
289286 """Set the plot data using a volume index"""
290- v = self ._data_idx ['v' ] if v is None else int (round (v ))
291- if v == self ._data_idx ['v' ]:
287+ v = self ._data_idx [3 ] if v is None else int (round (v ))
288+ if v == self ._data_idx [3 ]:
292289 return
293290 max_ = np .prod (self ._volume_dims )
294- self ._data_idx ['v' ] = max (min (int (round (v )), max_ - 1 ), 0 )
291+ self ._data_idx [3 ] = max (min (int (round (v )), max_ - 1 ), 0 )
295292 idx = (slice (None ), slice (None ), slice (None ))
296293 if self ._data .ndim > 3 :
297- idx = idx + tuple (np .unravel_index (self ._data_idx ['v' ],
294+ idx = idx + tuple (np .unravel_index (self ._data_idx [3 ],
298295 self ._volume_dims ))
299296 self ._current_vol_data = self ._data [idx ]
300297 # update all of our slice plots
@@ -314,108 +311,104 @@ def _set_position(self, x, y, z, notify=True):
314311 # deal with slicing appropriately
315312 self ._position [:3 ] = [x , y , z ]
316313 idxs = np .dot (self ._inv_affine , self ._position )[:3 ]
317- for key , idx in zip ('xyz' , idxs ):
318- self ._data_idx [key ] = max (min (int (round (idx )),
319- self ._sizes [key ] - 1 ), 0 )
320- for key in 'xyz' :
314+ for ii , (size , idx ) in enumerate (zip (self ._sizes , idxs )):
315+ self ._data_idx [ii ] = max (min (int (round (idx )), size - 1 ), 0 )
316+ for ii in range (3 ):
321317 # saggital: get to S/A
322318 # coronal: get to S/L
323319 # axial: get to A/L
324- data = np .take (self ._current_vol_data , self ._data_idx [key ],
325- axis = self ._order [key ])
326- xax = dict ( x = 'y' , y = 'x' , z = 'x' )[ key ]
327- yax = dict ( x = 'z' , y = 'z' , z = 'y' )[ key ]
320+ data = np .take (self ._current_vol_data , self ._data_idx [ii ],
321+ axis = self ._order [ii ])
322+ xax = [ 1 , 0 , 0 ][ ii ]
323+ yax = [ 2 , 2 , 1 ][ ii ]
328324 if self ._order [xax ] < self ._order [yax ]:
329325 data = data .T
330326 if self ._flips [xax ]:
331327 data = data [:, ::- 1 ]
332328 if self ._flips [yax ]:
333329 data = data [::- 1 ]
334- self ._ims [key ].set_data (data )
330+ self ._ims [ii ].set_data (data )
335331 # deal with crosshairs
336- loc = self ._data_idx [key ]
337- if self ._flips [key ]:
338- loc = self ._sizes [key ] - loc
332+ loc = self ._data_idx [ii ]
333+ if self ._flips [ii ]:
334+ loc = self ._sizes [ii ] - loc
339335 loc = [loc ] * 2
340- if key == 'x' :
341- self ._crosshairs ['z' ]['vert' ].set_xdata (loc )
342- self ._crosshairs ['y' ]['vert' ].set_xdata (loc )
343- elif key == 'y' :
344- self ._crosshairs ['z' ]['horiz' ].set_ydata (loc )
345- self ._crosshairs ['x' ]['vert' ].set_xdata (loc )
346- else : # key == 'z'
347- self ._crosshairs ['y' ]['horiz' ].set_ydata (loc )
348- self ._crosshairs ['x' ]['horiz' ].set_ydata (loc )
336+ if ii == 0 :
337+ self ._crosshairs [2 ]['vert' ].set_xdata (loc )
338+ self ._crosshairs [1 ]['vert' ].set_xdata (loc )
339+ elif ii == 1 :
340+ self ._crosshairs [2 ]['horiz' ].set_ydata (loc )
341+ self ._crosshairs [0 ]['vert' ].set_xdata (loc )
342+ else : # ii == 2
343+ self ._crosshairs [1 ]['horiz' ].set_ydata (loc )
344+ self ._crosshairs [0 ]['horiz' ].set_ydata (loc )
349345
350346 # Update volume trace
351- if self .n_volumes > 1 and 'v' in self ._axes :
352- idx = [0 ] * 3
353- for key in 'xyz' :
354- idx [self ._order [key ]] = self ._data_idx [key ]
355- vdata = self ._data [idx [ 0 ], idx [ 1 ], idx [ 2 ], : ].ravel ()
347+ if self .n_volumes > 1 and len ( self ._axes ) > 3 :
348+ idx = [None , Ellipsis ] * 3
349+ for ii in range ( 3 ) :
350+ idx [self ._order [ii ]] = self ._data_idx [ii ]
351+ vdata = self ._data [idx ].ravel ()
356352 vdata = np .concatenate ((vdata , [vdata [- 1 ]]))
357- self ._volume_ax_objs ['patch' ].set_x (self ._data_idx ['v' ] - 0.5 )
353+ self ._volume_ax_objs ['patch' ].set_x (self ._data_idx [3 ] - 0.5 )
358354 self ._volume_ax_objs ['step' ].set_ydata (vdata )
359355 if notify :
360356 self ._notify_links ()
361357 self ._changing = False
362358
363359 # Matplotlib handlers ####################################################
364360 def _in_axis (self , event ):
365- """Return axis key if within one of our axes, else None"""
361+ """Return axis index if within one of our axes, else None"""
366362 if getattr (event , 'inaxes' ) is None :
367363 return None
368- for key , ax in self ._axes . items ( ):
364+ for ii , ax in enumerate ( self ._axes ):
369365 if event .inaxes is ax :
370- return key
366+ return ii
371367
372368 def _on_scroll (self , event ):
373369 """Handle mpl scroll wheel event"""
374370 assert event .button in ('up' , 'down' )
375- key = self ._in_axis (event )
376- if key is None :
371+ ii = self ._in_axis (event )
372+ if ii is None :
377373 return
378374 if event .key is not None and 'shift' in event .key :
379375 if self .n_volumes <= 1 :
380376 return
381- key = 'v' # shift: change volume in any axis
382- assert key in [ 'x' , 'y' , 'z' , 'v' ]
377+ ii = 3 # shift: change volume in any axis
378+ assert ii in range ( 4 )
383379 dv = 10. if event .key is not None and 'control' in event .key else 1.
384380 dv *= 1. if event .button == 'up' else - 1.
385- dv *= - 1 if self ._flips . get ( key , False ) else 1
386- val = self ._data_idx [key ] + dv
387- if key == 'v' :
381+ dv *= - 1 if self ._flips [ ii ] else 1
382+ val = self ._data_idx [ii ] + dv
383+ if ii == 3 :
388384 self ._set_volume_index (val )
389385 else :
390- coords = {key : val }
391- for k in 'xyz' :
392- if k not in coords :
393- coords [k ] = self ._data_idx [k ]
394- coords = np .array ([coords ['x' ], coords ['y' ], coords ['z' ], 1. ])
395- coords = np .dot (self ._affine , coords )[:3 ]
396- self ._set_position (coords [0 ], coords [1 ], coords [2 ])
386+ coords = [self ._data_idx [k ] for k in range (3 )] + [1. ]
387+ coords [ii ] = val
388+ self ._set_position (* np .dot (self ._affine , coords )[:3 ])
397389 self ._draw ()
398390
399391 def _on_mouse (self , event ):
400392 """Handle mpl mouse move and button press events"""
401393 if event .button != 1 : # only enabled while dragging
402394 return
403- key = self ._in_axis (event )
404- if key is None :
395+ ii = self ._in_axis (event )
396+ if ii is None :
405397 return
406- if key == 'v' :
398+ if ii == 3 :
407399 # volume plot directly translates
408400 self ._set_volume_index (event .xdata )
409401 else :
410402 # translate click xdata/ydata to physical position
411- xax , yax = dict ( x = 'yz' , y = 'xz' , z = 'xy' )[ key ]
403+ xax , yax = [[ 1 , 2 ], [ 0 , 2 ], [ 0 , 1 ]][ ii ]
412404 x , y = event .xdata , event .ydata
413405 x = self ._sizes [xax ] - x if self ._flips [xax ] else x
414406 y = self ._sizes [yax ] - y if self ._flips [yax ] else y
415- idxs = {xax : x , yax : y , key : self ._data_idx [key ]}
416- idxs = np .array ([idxs ['x' ], idxs ['y' ], idxs ['z' ], 1. ])
417- pos = np .dot (self ._affine , idxs )[:3 ]
418- self ._set_position (* pos )
407+ idxs = [None , None , None , 1. ]
408+ idxs [xax ] = x
409+ idxs [yax ] = y
410+ idxs [ii ] = self ._data_idx [ii ]
411+ self ._set_position (* np .dot (self ._affine , idxs )[:3 ])
419412 self ._draw ()
420413
421414 def _on_keypress (self , event ):
@@ -425,14 +418,14 @@ def _on_keypress(self, event):
425418
426419 def _draw (self ):
427420 """Update all four (or three) plots"""
428- for key in 'xyz' :
429- ax , im = self ._axes [key ], self . _ims [ key ]
430- ax .draw_artist (im )
431- for line in self ._crosshairs [key ].values ():
421+ for ii in range ( 3 ) :
422+ ax = self ._axes [ii ]
423+ ax .draw_artist (self . _ims [ ii ] )
424+ for line in self ._crosshairs [ii ].values ():
432425 ax .draw_artist (line )
433426 ax .figure .canvas .blit (ax .bbox )
434- if self .n_volumes > 1 and 'v' in self ._axes : # user might only pass 3
435- ax = self ._axes ['v' ]
427+ if self .n_volumes > 1 and len ( self ._axes ) > 3 :
428+ ax = self ._axes [3 ]
436429 ax .draw_artist (ax .patch ) # axis bgcolor to erase old lines
437430 for key in ('step' , 'patch' ):
438431 ax .draw_artist (self ._volume_ax_objs [key ])
0 commit comments