Skip to content

Commit 0675b19

Browse files
committed
PivotToPermutation: Stick with default int64 behavior
1 parent c56bc7e commit 0675b19

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

pytensor/link/numba/dispatch/linalg/decomposition/lu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313

1414
@numba_basic.numba_njit
15-
def _pivot_to_permutation(p, dtype):
16-
p_inv = np.arange(len(p)).astype(dtype)
15+
def _pivot_to_permutation(p):
16+
p_inv = np.arange(len(p))
1717
for i in range(len(p)):
1818
p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
1919
return p_inv
@@ -29,7 +29,7 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
2929

3030
# Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
3131
IPIV = IPIV - 1
32-
p_inv = _pivot_to_permutation(IPIV, dtype=dtype)
32+
p_inv = _pivot_to_permutation(IPIV)
3333
perm = np.argsort(p_inv).astype("int32")
3434

3535
return perm, L, U

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,10 @@ def cholesky(a):
9797
@register_funcify_default_op_cache_key(PivotToPermutations)
9898
def pivot_to_permutation(op, node, **kwargs):
9999
inverse = op.inverse
100-
dtype = node.outputs[0].dtype
101100

102101
@numba_basic.numba_njit
103102
def numba_pivot_to_permutation(piv):
104-
p_inv = _pivot_to_permutation(piv, dtype)
103+
p_inv = _pivot_to_permutation(piv)
105104

106105
if inverse:
107106
return p_inv

0 commit comments

Comments
 (0)