
Commit c56bc7e

Linalg Ops: Align output dtypes with those of numpy/scipy
1 parent 7a2c5db commit c56bc7e
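
As a quick illustration of the numpy behaviour these Ops are being aligned with (a minimal check, not part of the commit): numpy's linear-algebra routines promote integer and boolean inputs to float64, whereas the old Ops declared their outputs with the input dtype.

    import numpy as np

    x = np.eye(3, dtype="int32")
    assert np.linalg.inv(x).dtype == np.dtype("float64")  # numpy promotes int inputs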

File tree

3 files changed: +45 -23 lines changed


pytensor/tensor/nlinalg.py

Lines changed: 30 additions & 13 deletions
@@ -18,8 +18,9 @@
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.type import (
     Variable,
+    dmatrix,
     dvector,
-    lscalar,
+    iscalar,
     matrix,
     scalar,
     tensor,
@@ -37,12 +38,16 @@ def __init__(self, hermitian):
     def make_node(self, x):
         x = as_tensor_variable(x)
         assert x.ndim == 2
-        return Apply(self, [x], [x.type()])
+        if x.type.numpy_dtype.kind in "ibu":
+            out_dtype = "float64"
+        else:
+            out_dtype = x.dtype
+        return Apply(self, [x], [matrix(shape=x.type.shape, dtype=out_dtype)])

     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (z,) = outputs
-        z[0] = np.linalg.pinv(x, hermitian=self.hermitian).astype(x.dtype)
+        z[0] = np.linalg.pinv(x, hermitian=self.hermitian)

     def L_op(self, inputs, outputs, g_outputs):
         r"""The gradient function should return
@@ -117,12 +122,16 @@ def __init__(self):
     def make_node(self, x):
         x = as_tensor_variable(x)
         assert x.ndim == 2
-        return Apply(self, [x], [x.type()])
+        if x.type.numpy_dtype.kind in "ibu":
+            out_dtype = "float64"
+        else:
+            out_dtype = x.dtype
+        return Apply(self, [x], [matrix(shape=x.type.shape, dtype=out_dtype)])

     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (z,) = outputs
-        z[0] = np.linalg.inv(x).astype(x.dtype)
+        z[0] = np.linalg.inv(x)

     def grad(self, inputs, g_outputs):
         r"""The gradient function should return
@@ -216,14 +225,18 @@ def make_node(self, x):
             raise ValueError(
                 f"Determinant not defined for non-square matrix inputs. Shape received is {x.type.shape}"
             )
-        o = scalar(dtype=x.dtype)
+        if x.type.numpy_dtype.kind in "ibu":
+            out_dtype = "float64"
+        else:
+            out_dtype = x.dtype
+        o = scalar(dtype=out_dtype)
         return Apply(self, [x], [o])

     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (z,) = outputs
         try:
-            z[0] = np.asarray(np.linalg.det(x), dtype=x.dtype)
+            z[0] = np.asarray(np.linalg.det(x))
         except Exception as e:
             raise ValueError("Failed to compute determinant", x) from e

@@ -254,15 +267,19 @@ class SLogDet(Op):
     def make_node(self, x):
         x = as_tensor_variable(x)
         assert x.ndim == 2
-        sign = scalar(dtype=x.dtype)
-        det = scalar(dtype=x.dtype)
+        if x.type.numpy_dtype.kind in "ibu":
+            out_dtype = "float64"
+        else:
+            out_dtype = x.dtype
+        sign = scalar(dtype=out_dtype)
+        det = scalar(dtype=out_dtype)
         return Apply(self, [x], [sign, det])

     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (sign, det) = outputs
         try:
-            sign[0], det[0] = (np.array(z, dtype=x.dtype) for z in np.linalg.slogdet(x))
+            sign[0], det[0] = (np.array(z) for z in np.linalg.slogdet(x))
         except Exception as e:
             raise ValueError("Failed to compute determinant", x) from e
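
A hedged check of the scalar dtypes that `perform` now passes through unchanged for `Det` and `SLogDet` (the matrix values are arbitrary):

    import numpy as np

    x = np.array([[1, 2], [3, 4]], dtype="int8")
    det = np.asarray(np.linalg.det(x))
    sign, logdet = (np.array(z) for z in np.linalg.slogdet(x))

    # numpy already returns float64 for integer input, so the explicit
    # dtype=x.dtype casts were dropped
    assert det.dtype == np.dtype("float64")
    assert sign.dtype == logdet.dtype == np.dtype("float64")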

@@ -735,9 +752,9 @@ def make_node(self, x, y, rcond):
             self,
             [x, y, rcond],
             [
-                matrix(),
+                dmatrix(),
                 dvector(),
-                lscalar(),
+                iscalar(),
                 dvector(),
             ],
         )
@@ -746,7 +763,7 @@ def perform(self, node, inputs, outputs):
         zz = np.linalg.lstsq(inputs[0], inputs[1], inputs[2])
         outputs[0][0] = zz[0]
         outputs[1][0] = zz[1]
-        outputs[2][0] = np.array(zz[2])
+        outputs[2][0] = np.asarray(zz[2])
         outputs[3][0] = zz[3]
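
For the `Lstsq` output types, a minimal check of what `np.linalg.lstsq` returns (the matrices and `rcond` value are arbitrary example inputs):

    import numpy as np

    a = np.array([[2.0, 1.0], [3.0, 4.0]])
    b = np.array([[17.0, 20.0], [43.0, 50.0]])
    solution, residuals, rank, singular_values = np.linalg.lstsq(a, b, rcond=None)

    assert solution.dtype == np.dtype("float64")         # -> dmatrix()
    assert singular_values.dtype == np.dtype("float64")  # -> dvector()
    assert isinstance(rank, (int, np.integer))           # stored via np.asarray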

pytensor/tensor/slinalg.py

Lines changed: 11 additions & 7 deletions
@@ -491,20 +491,24 @@ def make_node(self, x):
                 f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
             )

-        real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
-        p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
-
-        L = tensor(shape=x.type.shape, dtype=x.type.dtype)
-        U = tensor(shape=x.type.shape, dtype=x.type.dtype)
+        if x.type.numpy_dtype.kind in "ibu":
+            if x.type.numpy_dtype.itemsize <= 2:
+                out_dtype = "float32"
+            else:
+                out_dtype = "float64"
+        else:
+            out_dtype = x.type.dtype
+        L = tensor(shape=x.type.shape, dtype=out_dtype)
+        U = tensor(shape=x.type.shape, dtype=out_dtype)

         if self.permute_l:
             # In this case, L is actually P @ L
             return Apply(self, inputs=[x], outputs=[L, U])
         if self.p_indices:
-            p_indices = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
+            p_indices = tensor(shape=(x.type.shape[0],), dtype="int32")
             return Apply(self, inputs=[x], outputs=[p_indices, L, U])

-        P = tensor(shape=x.type.shape, dtype=p_dtype)
+        P = tensor(shape=x.type.shape, dtype=out_dtype)
         return Apply(self, inputs=[x], outputs=[P, L, U])

     def perform(self, node, inputs, outputs):
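
The new integer branch promotes 1- and 2-byte integers (which fit exactly in single precision) to float32 and wider integers to float64; whether scipy.linalg.lu itself makes that same choice is the behaviour the commit assumes. A sketch of just the dtype selection, using a hypothetical helper name:

    import numpy as np

    def lu_out_dtype(dtype):
        """Hypothetical helper mirroring the new make_node branch."""
        dtype = np.dtype(dtype)
        if dtype.kind in "ibu":
            # bool/int8/uint8/int16/uint16 are exactly representable in float32
            return "float32" if dtype.itemsize <= 2 else "float64"
        return dtype.name

    assert lu_out_dtype("int16") == "float32"
    assert lu_out_dtype("int64") == "float64"
    assert lu_out_dtype("float32") == "float32"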

tests/tensor/test_nlinalg.py

Lines changed: 4 additions & 3 deletions
@@ -502,12 +502,13 @@ def test_correct_solution(self):
         z = lscalar()
         b = lstsq(x, y, z)
         f = function([x, y, z], b)
+
         TestMatrix1 = np.asarray([[2, 1], [3, 4]])
         TestMatrix2 = np.asarray([[17, 20], [43, 50]])
         TestScalar = np.asarray(1)
-        f = function([x, y, z], b)
-        m = f(TestMatrix1, TestMatrix2, TestScalar)
-        assert np.allclose(TestMatrix2, np.dot(TestMatrix1, m[0]))
+        m0, _, rank, _ = f(TestMatrix1, TestMatrix2, TestScalar)
+        assert rank.dtype == "int32"
+        assert np.allclose(TestMatrix2, np.dot(TestMatrix1, m0))

     def test_wrong_coefficient_matrix(self):
         x = vector()
