@@ -63,6 +63,41 @@ def test_random_updates(rng_ctor):
63
63
)
64
64
65
65
66
def test_random_updates_input_storage_order():
    """Regression test for issue #314.

    Updating the input storage after cloning the shared RNG used to call
    ``input_storage.index(old_input_storage)``. When the storage held numpy
    arrays positioned before the RNG value, the element-wise equality check
    performed by ``index`` raised, so compilation failed.
    """
    rng = RandomStream(1)

    batch_shape = (3, 1, 4, 4)
    shared_inp = pytensor.shared(np.zeros(batch_shape), name="inp_shared")

    sym_inp = at.tensor4(name="inp")
    update_expr = sym_inp + rng.normal(size=sym_inp.shape, loc=5, scale=1e-5)

    # Substituting shared_inp for sym_inp via `givens` pushes the RNG shared
    # variable behind shared_inp in the input_storage, which is exactly the
    # ordering that used to break the storage-index lookup.
    with pytest.warns(
        UserWarning,
        match=r"The RandomType SharedVariables \[.+\] will not be used",
    ):
        compiled_fn = pytensor.function(
            inputs=[],
            outputs=[],
            updates={shared_inp: update_expr},
            givens={sym_inp: shared_inp},
            mode="JAX",
        )

    # Each call should add ~5 to every element of the shared value.
    compiled_fn()
    np.testing.assert_allclose(shared_inp.get_value(), 5, rtol=1e-3)
    compiled_fn()
    np.testing.assert_allclose(shared_inp.get_value(), 10, rtol=1e-3)
+
66
101
@pytest .mark .parametrize (
67
102
"rv_op, dist_params, base_size, cdf_name, params_conv" ,
68
103
[
0 commit comments