Skip to content

Commit b54577d

Browse files
committed
Fix bug in JAX cloning of RNG shared variables
1 parent 4ae7133 commit b54577d

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

pytensor/link/jax/linker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
4444
new_inp_storage = [new_inp.get_value(borrow=True)]
4545
storage_map[new_inp] = new_inp_storage
4646
old_inp_storage = storage_map.pop(old_inp)
47-
input_storage[input_storage.index(old_inp_storage)] = new_inp_storage
47+
for input_storage_idx, input_storage_item in enumerate(input_storage):
48+
# We have to establish equality based on identity because input_storage may contain numpy arrays
49+
if input_storage_item is old_inp_storage:
50+
break
51+
input_storage[input_storage_idx] = new_inp_storage
4852
fgraph.remove_input(
4953
fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
5054
)

tests/link/jax/test_random.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,41 @@ def test_random_updates(rng_ctor):
6363
)
6464

6565

66+
def test_random_updates_input_storage_order():
67+
"""Test case described in issue #314.
68+
69+
This happened when we tried to update the input storage after we clone the shared RNG.
70+
We used to call `input_storage.index(old_input_storage)` which would fail when the input_storage contained
71+
numpy arrays before the RNG value, which would fail the equality check.
72+
73+
"""
74+
pt_rng = RandomStream(1)
75+
76+
batchshape = (3, 1, 4, 4)
77+
inp_shared = pytensor.shared(np.zeros(batchshape), name="inp_shared")
78+
79+
inp = at.tensor4(name="inp")
80+
inp_update = inp + pt_rng.normal(size=inp.shape, loc=5, scale=1e-5)
81+
82+
# This function replaces inp by input_shared in the update expression
83+
# This is what caused the RNG to appear later than inp_shared in the input_storage
84+
with pytest.warns(
85+
UserWarning,
86+
match=r"The RandomType SharedVariables \[.+\] will not be used",
87+
):
88+
fn = pytensor.function(
89+
inputs=[],
90+
outputs=[],
91+
updates={inp_shared: inp_update},
92+
givens={inp: inp_shared},
93+
mode="JAX",
94+
)
95+
fn()
96+
np.testing.assert_allclose(inp_shared.get_value(), 5, rtol=1e-3)
97+
fn()
98+
np.testing.assert_allclose(inp_shared.get_value(), 10, rtol=1e-3)
99+
100+
66101
@pytest.mark.parametrize(
67102
"rv_op, dist_params, base_size, cdf_name, params_conv",
68103
[

0 commit comments

Comments
 (0)