1818from pytensor .tensor .blockwise import Blockwise
1919from pytensor .tensor .type import (
2020 Variable ,
21+ dmatrix ,
2122 dvector ,
22- lscalar ,
23+ iscalar ,
2324 matrix ,
2425 scalar ,
2526 tensor ,
@@ -37,12 +38,16 @@ def __init__(self, hermitian):
3738 def make_node (self , x ):
3839 x = as_tensor_variable (x )
3940 assert x .ndim == 2
40- return Apply (self , [x ], [x .type ()])
41+ if x .type .numpy_dtype .kind in "ibu" :
42+ out_dtype = "float64"
43+ else :
44+ out_dtype = x .dtype
45+ return Apply (self , [x ], [matrix (shape = x .type .shape , dtype = out_dtype )])
4146
4247 def perform (self , node , inputs , outputs ):
4348 (x ,) = inputs
4449 (z ,) = outputs
45- z [0 ] = np .linalg .pinv (x , hermitian = self .hermitian ). astype ( x . dtype )
50+ z [0 ] = np .linalg .pinv (x , hermitian = self .hermitian )
4651
4752 def L_op (self , inputs , outputs , g_outputs ):
4853 r"""The gradient function should return
@@ -117,12 +122,16 @@ def __init__(self):
117122 def make_node (self , x ):
118123 x = as_tensor_variable (x )
119124 assert x .ndim == 2
120- return Apply (self , [x ], [x .type ()])
125+ if x .type .numpy_dtype .kind in "ibu" :
126+ out_dtype = "float64"
127+ else :
128+ out_dtype = x .dtype
129+ return Apply (self , [x ], [matrix (shape = x .type .shape , dtype = out_dtype )])
121130
122131 def perform (self , node , inputs , outputs ):
123132 (x ,) = inputs
124133 (z ,) = outputs
125- z [0 ] = np .linalg .inv (x ). astype ( x . dtype )
134+ z [0 ] = np .linalg .inv (x )
126135
127136 def grad (self , inputs , g_outputs ):
128137 r"""The gradient function should return
@@ -216,14 +225,18 @@ def make_node(self, x):
216225 raise ValueError (
217226 f"Determinant not defined for non-square matrix inputs. Shape received is { x .type .shape } "
218227 )
219- o = scalar (dtype = x .dtype )
228+ if x .type .numpy_dtype .kind in "ibu" :
229+ out_dtype = "float64"
230+ else :
231+ out_dtype = x .dtype
232+ o = scalar (dtype = out_dtype )
220233 return Apply (self , [x ], [o ])
221234
222235 def perform (self , node , inputs , outputs ):
223236 (x ,) = inputs
224237 (z ,) = outputs
225238 try :
226- z [0 ] = np .asarray (np .linalg .det (x ), dtype = x . dtype )
239+ z [0 ] = np .asarray (np .linalg .det (x ))
227240 except Exception as e :
228241 raise ValueError ("Failed to compute determinant" , x ) from e
229242
@@ -254,15 +267,19 @@ class SLogDet(Op):
254267 def make_node (self , x ):
255268 x = as_tensor_variable (x )
256269 assert x .ndim == 2
257- sign = scalar (dtype = x .dtype )
258- det = scalar (dtype = x .dtype )
270+ if x .type .numpy_dtype .kind in "ibu" :
271+ out_dtype = "float64"
272+ else :
273+ out_dtype = x .dtype
274+ sign = scalar (dtype = out_dtype )
275+ det = scalar (dtype = out_dtype )
259276 return Apply (self , [x ], [sign , det ])
260277
261278 def perform (self , node , inputs , outputs ):
262279 (x ,) = inputs
263280 (sign , det ) = outputs
264281 try :
265- sign [0 ], det [0 ] = (np .array (z , dtype = x . dtype ) for z in np .linalg .slogdet (x ))
282+ sign [0 ], det [0 ] = (np .array (z ) for z in np .linalg .slogdet (x ))
266283 except Exception as e :
267284 raise ValueError ("Failed to compute determinant" , x ) from e
268285
@@ -735,9 +752,9 @@ def make_node(self, x, y, rcond):
735752 self ,
736753 [x , y , rcond ],
737754 [
738- matrix (),
755+ dmatrix (),
739756 dvector (),
740- lscalar (),
757+ iscalar (),
741758 dvector (),
742759 ],
743760 )
@@ -746,7 +763,7 @@ def perform(self, node, inputs, outputs):
746763 zz = np .linalg .lstsq (inputs [0 ], inputs [1 ], inputs [2 ])
747764 outputs [0 ][0 ] = zz [0 ]
748765 outputs [1 ][0 ] = zz [1 ]
749- outputs [2 ][0 ] = np .array (zz [2 ])
766+ outputs [2 ][0 ] = np .asarray (zz [2 ])
750767 outputs [3 ][0 ] = zz [3 ]
751768
752769
0 commit comments