Skip to content

Commit 2c3504c

Browse files
committed
Workaround numba RandomState bug
1 parent 0e22695 commit 2c3504c

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

pytensor/link/numba/dispatch/random.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def body_fn(a):
312312
def numba_funcify_CategoricalRV(op, node, **kwargs):
313313
out_dtype = node.outputs[1].type.numpy_dtype
314314
size_len = int(get_vector_length(node.inputs[1]))
315+
p_ndim = node.inputs[-1].ndim
315316

316317
@numba_basic.numba_njit
317318
def categorical_rv(rng, size, dtype, p):
@@ -321,7 +322,11 @@ def categorical_rv(rng, size, dtype, p):
321322
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
322323
p = np.broadcast_to(p, size_tpl + p.shape[-1:])
323324

324-
unif_samples = np.random.uniform(0, 1, size_tpl)
325+
# Workaround https://github.com/numba/numba/issues/8975
326+
if not size_len and p_ndim == 1:
327+
unif_samples = np.asarray(np.random.uniform(0, 1))
328+
else:
329+
unif_samples = np.random.uniform(0, 1, size_tpl)
325330

326331
res = np.empty(size_tpl, dtype=out_dtype)
327332
for idx in np.ndindex(*size_tpl):

0 commit comments

Comments
 (0)