Skip to content

Commit ac8f1f2

Browse files
committed
Intercept UserWarning on JAX random function tests
1 parent b54577d commit ac8f1f2

File tree

1 file changed

+49
-49
lines changed

1 file changed

+49
-49
lines changed

tests/link/jax/test_random.py

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@
2222
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
2323

2424

25+
# turns all warnings into errors for this module
26+
pytestmark = pytest.mark.filterwarnings("error")
27+
28+
29+
def random_function(*args, **kwargs):
30+
with pytest.warns(
31+
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
32+
):
33+
return function(*args, **kwargs)
34+
35+
2536
def test_random_RandomStream():
2637
"""Two successive calls of a compiled graph using `RandomStream` should
2738
return different values.
@@ -30,11 +41,7 @@ def test_random_RandomStream():
3041
srng = RandomStream(seed=123)
3142
out = srng.normal() - srng.normal()
3243

33-
with pytest.warns(
34-
UserWarning,
35-
match=r"The RandomType SharedVariables \[.+\] will not be used",
36-
):
37-
fn = function([], out, mode=jax_mode)
44+
fn = random_function([], out, mode=jax_mode)
3845
jax_res_1 = fn()
3946
jax_res_2 = fn()
4047

@@ -47,13 +54,7 @@ def test_random_updates(rng_ctor):
4754
rng = shared(original_value, name="original_rng", borrow=False)
4855
next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs
4956

50-
with pytest.warns(
51-
UserWarning,
52-
match=re.escape(
53-
"The RandomType SharedVariables [original_rng] will not be used"
54-
),
55-
):
56-
f = pytensor.function([], [x], updates={rng: next_rng}, mode=jax_mode)
57+
f = random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
5758
assert f() != f()
5859

5960
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -81,17 +82,14 @@ def test_random_updates_input_storage_order():
8182

8283
# This function replaces inp by input_shared in the update expression
8384
# This is what caused the RNG to appear later than inp_shared in the input_storage
84-
with pytest.warns(
85-
UserWarning,
86-
match=r"The RandomType SharedVariables \[.+\] will not be used",
87-
):
88-
fn = pytensor.function(
89-
inputs=[],
90-
outputs=[],
91-
updates={inp_shared: inp_update},
92-
givens={inp: inp_shared},
93-
mode="JAX",
94-
)
85+
86+
fn = random_function(
87+
inputs=[],
88+
outputs=[],
89+
updates={inp_shared: inp_update},
90+
givens={inp: inp_shared},
91+
mode="JAX",
92+
)
9593
fn()
9694
np.testing.assert_allclose(inp_shared.get_value(), 5, rtol=1e-3)
9795
fn()
@@ -455,7 +453,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
455453
else:
456454
rng = shared(np.random.RandomState(29402))
457455
g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
458-
g_fn = function(dist_params, g, mode=jax_mode)
456+
g_fn = random_function(dist_params, g, mode=jax_mode)
459457
samples = g_fn(
460458
*[
461459
i.tag.test_value
@@ -479,7 +477,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
479477
def test_random_bernoulli(size):
480478
rng = shared(np.random.RandomState(123))
481479
g = at.random.bernoulli(0.5, size=(1000,) + size, rng=rng)
482-
g_fn = function([], g, mode=jax_mode)
480+
g_fn = random_function([], g, mode=jax_mode)
483481
samples = g_fn()
484482
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
485483

@@ -490,7 +488,7 @@ def test_random_mvnormal():
490488
mu = np.ones(4)
491489
cov = np.eye(4)
492490
g = at.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
493-
g_fn = function([], g, mode=jax_mode)
491+
g_fn = random_function([], g, mode=jax_mode)
494492
samples = g_fn()
495493
np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)
496494

@@ -505,7 +503,7 @@ def test_random_mvnormal():
505503
def test_random_dirichlet(parameter, size):
506504
rng = shared(np.random.RandomState(123))
507505
g = at.random.dirichlet(parameter, size=(1000,) + size, rng=rng)
508-
g_fn = function([], g, mode=jax_mode)
506+
g_fn = random_function([], g, mode=jax_mode)
509507
samples = g_fn()
510508
np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)
511509

@@ -515,29 +513,29 @@ def test_random_choice():
515513
num_samples = 10000
516514
rng = shared(np.random.RandomState(123))
517515
g = at.random.choice(np.arange(4), size=num_samples, rng=rng)
518-
g_fn = function([], g, mode=jax_mode)
516+
g_fn = random_function([], g, mode=jax_mode)
519517
samples = g_fn()
520518
np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)
521519

522520
# `replace=False` produces unique results
523521
rng = shared(np.random.RandomState(123))
524522
g = at.random.choice(np.arange(100), replace=False, size=99, rng=rng)
525-
g_fn = function([], g, mode=jax_mode)
523+
g_fn = random_function([], g, mode=jax_mode)
526524
samples = g_fn()
527525
assert len(np.unique(samples)) == 99
528526

529527
# We can pass an array with probabilities
530528
rng = shared(np.random.RandomState(123))
531529
g = at.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
532-
g_fn = function([], g, mode=jax_mode)
530+
g_fn = random_function([], g, mode=jax_mode)
533531
samples = g_fn()
534532
np.testing.assert_allclose(samples, np.zeros(10))
535533

536534

537535
def test_random_categorical():
538536
rng = shared(np.random.RandomState(123))
539537
g = at.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
540-
g_fn = function([], g, mode=jax_mode)
538+
g_fn = random_function([], g, mode=jax_mode)
541539
samples = g_fn()
542540
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
543541

@@ -546,7 +544,7 @@ def test_random_permutation():
546544
array = np.arange(4)
547545
rng = shared(np.random.RandomState(123))
548546
g = at.random.permutation(array, rng=rng)
549-
g_fn = function([], g, mode=jax_mode)
547+
g_fn = random_function([], g, mode=jax_mode)
550548
permuted = g_fn()
551549
with pytest.raises(AssertionError):
552550
np.testing.assert_allclose(array, permuted)
@@ -556,7 +554,7 @@ def test_random_geometric():
556554
rng = shared(np.random.RandomState(123))
557555
p = np.array([0.3, 0.7])
558556
g = at.random.geometric(p, size=(10_000, 2), rng=rng)
559-
g_fn = function([], g, mode=jax_mode)
557+
g_fn = random_function([], g, mode=jax_mode)
560558
samples = g_fn()
561559
np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
562560
np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
@@ -567,7 +565,7 @@ def test_negative_binomial():
567565
n = np.array([10, 40])
568566
p = np.array([0.3, 0.7])
569567
g = at.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
570-
g_fn = function([], g, mode=jax_mode)
568+
g_fn = random_function([], g, mode=jax_mode)
571569
samples = g_fn()
572570
np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
573571
np.testing.assert_allclose(
@@ -581,7 +579,7 @@ def test_binomial():
581579
n = np.array([10, 40])
582580
p = np.array([0.3, 0.7])
583581
g = at.random.binomial(n, p, size=(10_000, 2), rng=rng)
584-
g_fn = function([], g, mode=jax_mode)
582+
g_fn = random_function([], g, mode=jax_mode)
585583
samples = g_fn()
586584
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
587585
np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
@@ -596,7 +594,7 @@ def test_beta_binomial():
596594
a = np.array([1.5, 13])
597595
b = np.array([0.5, 9])
598596
g = at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
599-
g_fn = function([], g, mode=jax_mode)
597+
g_fn = random_function([], g, mode=jax_mode)
600598
samples = g_fn()
601599
np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
602600
np.testing.assert_allclose(
@@ -614,7 +612,7 @@ def test_multinomial():
614612
n = np.array([10, 40])
615613
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
616614
g = at.random.multinomial(n, p, size=(10_000, 2), rng=rng)
617-
g_fn = function([], g, mode=jax_mode)
615+
g_fn = random_function([], g, mode=jax_mode)
618616
samples = g_fn()
619617
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
620618
np.testing.assert_allclose(
@@ -630,7 +628,7 @@ def test_vonmises_mu_outside_circle():
630628
mu = np.array([-30, 40])
631629
kappa = np.array([100, 10])
632630
g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
633-
g_fn = function([], g, mode=jax_mode)
631+
g_fn = random_function([], g, mode=jax_mode)
634632
samples = g_fn()
635633
np.testing.assert_allclose(
636634
samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
@@ -676,7 +674,10 @@ def rng_fn(cls, rng, size):
676674
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
677675

678676
with pytest.raises(NotImplementedError):
679-
compare_jax_and_py(fgraph, [])
677+
with pytest.warns(
678+
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
679+
):
680+
compare_jax_and_py(fgraph, [])
680681

681682

682683
def test_random_custom_implementation():
@@ -707,7 +708,10 @@ def sample_fn(rng, size, dtype, *parameters):
707708
rng = shared(np.random.RandomState(123))
708709
out = nonexistentrv(rng=rng)
709710
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
710-
compare_jax_and_py(fgraph, [])
711+
with pytest.warns(
712+
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
713+
):
714+
compare_jax_and_py(fgraph, [])
711715

712716

713717
def test_random_concrete_shape():
@@ -724,19 +728,15 @@ def test_random_concrete_shape():
724728
rng = shared(np.random.RandomState(123))
725729
x_at = at.dmatrix()
726730
out = at.random.normal(0, 1, size=x_at.shape, rng=rng)
727-
jax_fn = function([x_at], out, mode=jax_mode)
731+
jax_fn = random_function([x_at], out, mode=jax_mode)
728732
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
729733

730734

731735
def test_random_concrete_shape_from_param():
732736
rng = shared(np.random.RandomState(123))
733737
x_at = at.dmatrix()
734738
out = at.random.normal(x_at, 1, rng=rng)
735-
with pytest.warns(
736-
UserWarning,
737-
match="The RandomType SharedVariables \[.+\] will not be used"
738-
):
739-
jax_fn = function([x_at], out, mode=jax_mode)
739+
jax_fn = random_function([x_at], out, mode=jax_mode)
740740
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
741741

742742

@@ -755,7 +755,7 @@ def test_random_concrete_shape_subtensor():
755755
rng = shared(np.random.RandomState(123))
756756
x_at = at.dmatrix()
757757
out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng)
758-
jax_fn = function([x_at], out, mode=jax_mode)
758+
jax_fn = random_function([x_at], out, mode=jax_mode)
759759
assert jax_fn(np.ones((2, 3))).shape == (3,)
760760

761761

@@ -771,7 +771,7 @@ def test_random_concrete_shape_subtensor_tuple():
771771
rng = shared(np.random.RandomState(123))
772772
x_at = at.dmatrix()
773773
out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
774-
jax_fn = function([x_at], out, mode=jax_mode)
774+
jax_fn = random_function([x_at], out, mode=jax_mode)
775775
assert jax_fn(np.ones((2, 3))).shape == (2,)
776776

777777

@@ -782,5 +782,5 @@ def test_random_concrete_shape_graph_input():
782782
rng = shared(np.random.RandomState(123))
783783
size_at = at.scalar()
784784
out = at.random.normal(0, 1, size=size_at, rng=rng)
785-
jax_fn = function([size_at], out, mode=jax_mode)
785+
jax_fn = random_function([size_at], out, mode=jax_mode)
786786
assert jax_fn(10).shape == (10,)

0 commit comments

Comments
 (0)