@@ -201,6 +201,21 @@ def _axis_none_keepdims(x, ndim, keepdims):
             x = torch.unsqueeze(x, 0)
     return x
 
+def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
+    # Some reductions don't support multiple axes
+    # (https://github.com/pytorch/pytorch/issues/56586).
+    axes = _normalize_axes(axis, x.ndim)
+    for a in reversed(axes):
+        x = torch.movedim(x, a, -1)
+    x = torch.flatten(x, -len(axes))
+
+    out = f(x, -1, **kwargs)
+
+    if keepdims:
+        for a in axes:
+            out = torch.unsqueeze(out, a)
+    return out
+
 def prod(x: array,
          /,
          *,
@@ -226,14 +241,7 @@ def prod(x: array,
     # torch.prod doesn't support multiple axes
     # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
-        axes = _normalize_axes(axis, x.ndim)
-        for i, a in enumerate(axes):
-            if keepdims:
-                x = torch.prod(x, a, dtype=dtype, **kwargs)
-                x = torch.unsqueeze(x, a)
-            else:
-                x = torch.prod(x, a - i, dtype=dtype, **kwargs)
-        return x
+        return _reduce_multiple_axes(torch.prod, x, axis, keepdims=keepdims, dtype=dtype, **kwargs)
     if axis is None:
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
@@ -281,21 +289,15 @@ def any(x: array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> array:
-    # torch.any doesn't support multiple axes
-    # (https://github.com/pytorch/pytorch/issues/56586).
     x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
         return x.to(torch.bool)
+    # torch.any doesn't support multiple axes
+    # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
-        axes = _normalize_axes(axis, x.ndim)
-        for i, a in enumerate(axes):
-            if keepdims:
-                x = torch.any(x, a, **kwargs)
-                x = torch.unsqueeze(x, a)
-            else:
-                x = torch.any(x, a - i, **kwargs)
-        return x.to(torch.bool)
+        res = _reduce_multiple_axes(torch.any, x, axis, keepdims=keepdims, **kwargs)
+        return res.to(torch.bool)
     if axis is None:
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
@@ -312,21 +314,15 @@ def all(x: array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> array:
-    # torch.all doesn't support multiple axes
-    # (https://github.com/pytorch/pytorch/issues/56586).
     x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
         return x.to(torch.bool)
+    # torch.all doesn't support multiple axes
+    # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
-        axes = _normalize_axes(axis, ndim)
-        for i, a in enumerate(axes):
-            if keepdims:
-                x = torch.all(x, a, **kwargs)
-                x = torch.unsqueeze(x, a)
-            else:
-                x = torch.all(x, a - i, **kwargs)
-        return x.to(torch.bool)
+        res = _reduce_multiple_axes(torch.all, x, axis, keepdims=keepdims, **kwargs)
+        return res.to(torch.bool)
     if axis is None:
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
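
For reference, a minimal standalone sketch (not part of the commit) of what the new _reduce_multiple_axes helper does: it moves the requested axes to the end, flattens them into one trailing dimension, and reduces over that single dimension. The _normalize_axes stand-in below is simplified for illustration; the real helper is defined elsewhere in this module.

import torch

def _normalize_axes(axis, ndim):
    # simplified stand-in: wrap negative axes and sort them
    return tuple(sorted(a % ndim for a in axis))

def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
    axes = _normalize_axes(axis, x.ndim)
    for a in reversed(axes):
        x = torch.movedim(x, a, -1)
    # collapse the moved axes into one trailing dimension and reduce it
    x = torch.flatten(x, -len(axes))
    out = f(x, -1, **kwargs)
    if keepdims:
        # the reduced axes are gone; reinsert them as length-1 dims
        for a in axes:
            out = torch.unsqueeze(out, a)
    return out

x = torch.arange(24).reshape(2, 3, 4)
res = _reduce_multiple_axes(torch.prod, x, (0, 2), keepdims=True)
# matches reducing the same axes one at a time with keepdim=True
expected = x.prod(0, keepdim=True).prod(2, keepdim=True)
assert torch.equal(res, expected)  # both have shape (1, 3, 1)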