22
22
from pytensor .link .jax .dispatch .random import numpyro_available # noqa: E402
23
23
24
24
25
# Promote every warning raised in this module to an error, so any expected
# warning must be asserted explicitly (e.g. via `random_function` below).
pytestmark = pytest.mark.filterwarnings("error")


def random_function(*args, **kwargs):
    """Compile a PyTensor function while asserting the JAX RNG warning.

    Graphs compiled to JAX cannot make use of RandomType shared variables,
    and PyTensor emits a UserWarning saying so; every compilation in this
    module is expected to trigger it, so this wrapper captures it once.
    """
    expected_msg = r"The RandomType SharedVariables \[.+\] will not be used"
    with pytest.warns(UserWarning, match=expected_msg):
        return function(*args, **kwargs)
34
+
35
+
25
36
def test_random_RandomStream ():
26
37
"""Two successive calls of a compiled graph using `RandomStream` should
27
38
return different values.
@@ -30,11 +41,7 @@ def test_random_RandomStream():
30
41
srng = RandomStream (seed = 123 )
31
42
out = srng .normal () - srng .normal ()
32
43
33
- with pytest .warns (
34
- UserWarning ,
35
- match = r"The RandomType SharedVariables \[.+\] will not be used" ,
36
- ):
37
- fn = function ([], out , mode = jax_mode )
44
+ fn = random_function ([], out , mode = jax_mode )
38
45
jax_res_1 = fn ()
39
46
jax_res_2 = fn ()
40
47
@@ -47,13 +54,7 @@ def test_random_updates(rng_ctor):
47
54
rng = shared (original_value , name = "original_rng" , borrow = False )
48
55
next_rng , x = at .random .normal (name = "x" , rng = rng ).owner .outputs
49
56
50
- with pytest .warns (
51
- UserWarning ,
52
- match = re .escape (
53
- "The RandomType SharedVariables [original_rng] will not be used"
54
- ),
55
- ):
56
- f = pytensor .function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
57
+ f = random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
57
58
assert f () != f ()
58
59
59
60
# Check that original rng variable content was not overwritten when calling jax_typify
@@ -81,17 +82,14 @@ def test_random_updates_input_storage_order():
81
82
82
83
# This function replaces inp by input_shared in the update expression
83
84
# This is what caused the RNG to appear later than inp_shared in the input_storage
84
- with pytest .warns (
85
- UserWarning ,
86
- match = r"The RandomType SharedVariables \[.+\] will not be used" ,
87
- ):
88
- fn = pytensor .function (
89
- inputs = [],
90
- outputs = [],
91
- updates = {inp_shared : inp_update },
92
- givens = {inp : inp_shared },
93
- mode = "JAX" ,
94
- )
85
+
86
+ fn = random_function (
87
+ inputs = [],
88
+ outputs = [],
89
+ updates = {inp_shared : inp_update },
90
+ givens = {inp : inp_shared },
91
+ mode = "JAX" ,
92
+ )
95
93
fn ()
96
94
np .testing .assert_allclose (inp_shared .get_value (), 5 , rtol = 1e-3 )
97
95
fn ()
@@ -455,7 +453,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
455
453
else :
456
454
rng = shared (np .random .RandomState (29402 ))
457
455
g = rv_op (* dist_params , size = (10_000 ,) + base_size , rng = rng )
458
- g_fn = function (dist_params , g , mode = jax_mode )
456
+ g_fn = random_function (dist_params , g , mode = jax_mode )
459
457
samples = g_fn (
460
458
* [
461
459
i .tag .test_value
@@ -479,7 +477,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
479
477
def test_random_bernoulli(size):
    """JAX-compiled Bernoulli(0.5) draws should average to roughly 0.5."""
    rng_state = shared(np.random.RandomState(123))
    draws = at.random.bernoulli(0.5, size=(1000,) + size, rng=rng_state)
    sample_fn = random_function([], draws, mode=jax_mode)
    sampled = sample_fn()
    # Loose tolerance: we only check the empirical mean is in the right ballpark.
    np.testing.assert_allclose(sampled.mean(axis=0), 0.5, 1)
485
483
@@ -490,7 +488,7 @@ def test_random_mvnormal():
490
488
mu = np .ones (4 )
491
489
cov = np .eye (4 )
492
490
g = at .random .multivariate_normal (mu , cov , size = (10000 ,), rng = rng )
493
- g_fn = function ([], g , mode = jax_mode )
491
+ g_fn = random_function ([], g , mode = jax_mode )
494
492
samples = g_fn ()
495
493
np .testing .assert_allclose (samples .mean (axis = 0 ), mu , atol = 0.1 )
496
494
@@ -505,7 +503,7 @@ def test_random_mvnormal():
505
503
def test_random_dirichlet(parameter, size):
    """JAX-compiled Dirichlet draws should have component means near 0.5."""
    rng_state = shared(np.random.RandomState(123))
    draws = at.random.dirichlet(parameter, size=(1000,) + size, rng=rng_state)
    sample_fn = random_function([], draws, mode=jax_mode)
    sampled = sample_fn()
    # Loose tolerance: only the empirical mean's ballpark is checked.
    np.testing.assert_allclose(sampled.mean(axis=0), 0.5, 1)
511
509
@@ -515,29 +513,29 @@ def test_random_choice():
515
513
num_samples = 10000
516
514
rng = shared (np .random .RandomState (123 ))
517
515
g = at .random .choice (np .arange (4 ), size = num_samples , rng = rng )
518
- g_fn = function ([], g , mode = jax_mode )
516
+ g_fn = random_function ([], g , mode = jax_mode )
519
517
samples = g_fn ()
520
518
np .testing .assert_allclose (np .sum (samples == 3 ) / num_samples , 0.25 , 2 )
521
519
522
520
# `replace=False` produces unique results
523
521
rng = shared (np .random .RandomState (123 ))
524
522
g = at .random .choice (np .arange (100 ), replace = False , size = 99 , rng = rng )
525
- g_fn = function ([], g , mode = jax_mode )
523
+ g_fn = random_function ([], g , mode = jax_mode )
526
524
samples = g_fn ()
527
525
assert len (np .unique (samples )) == 99
528
526
529
527
# We can pass an array with probabilities
530
528
rng = shared (np .random .RandomState (123 ))
531
529
g = at .random .choice (np .arange (3 ), p = np .array ([1.0 , 0.0 , 0.0 ]), size = 10 , rng = rng )
532
- g_fn = function ([], g , mode = jax_mode )
530
+ g_fn = random_function ([], g , mode = jax_mode )
533
531
samples = g_fn ()
534
532
np .testing .assert_allclose (samples , np .zeros (10 ))
535
533
536
534
537
535
def test_random_categorical():
    """JAX-compiled categorical draws over 4 equiprobable classes.

    With uniform probabilities 0.25 over classes {0, 1, 2, 3}, the expected
    class index is (0 + 1 + 2 + 3) / 4 = 6 / 4.
    """
    rng_state = shared(np.random.RandomState(123))
    draws = at.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng_state)
    sample_fn = random_function([], draws, mode=jax_mode)
    sampled = sample_fn()
    np.testing.assert_allclose(sampled.mean(axis=0), 6 / 4, 1)
543
541
@@ -546,7 +544,7 @@ def test_random_permutation():
546
544
array = np .arange (4 )
547
545
rng = shared (np .random .RandomState (123 ))
548
546
g = at .random .permutation (array , rng = rng )
549
- g_fn = function ([], g , mode = jax_mode )
547
+ g_fn = random_function ([], g , mode = jax_mode )
550
548
permuted = g_fn ()
551
549
with pytest .raises (AssertionError ):
552
550
np .testing .assert_allclose (array , permuted )
@@ -556,7 +554,7 @@ def test_random_geometric():
556
554
rng = shared (np .random .RandomState (123 ))
557
555
p = np .array ([0.3 , 0.7 ])
558
556
g = at .random .geometric (p , size = (10_000 , 2 ), rng = rng )
559
- g_fn = function ([], g , mode = jax_mode )
557
+ g_fn = random_function ([], g , mode = jax_mode )
560
558
samples = g_fn ()
561
559
np .testing .assert_allclose (samples .mean (axis = 0 ), 1 / p , rtol = 0.1 )
562
560
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt ((1 - p ) / p ** 2 ), rtol = 0.1 )
@@ -567,7 +565,7 @@ def test_negative_binomial():
567
565
n = np .array ([10 , 40 ])
568
566
p = np .array ([0.3 , 0.7 ])
569
567
g = at .random .negative_binomial (n , p , size = (10_000 , 2 ), rng = rng )
570
- g_fn = function ([], g , mode = jax_mode )
568
+ g_fn = random_function ([], g , mode = jax_mode )
571
569
samples = g_fn ()
572
570
np .testing .assert_allclose (samples .mean (axis = 0 ), n * (1 - p ) / p , rtol = 0.1 )
573
571
np .testing .assert_allclose (
@@ -581,7 +579,7 @@ def test_binomial():
581
579
n = np .array ([10 , 40 ])
582
580
p = np .array ([0.3 , 0.7 ])
583
581
g = at .random .binomial (n , p , size = (10_000 , 2 ), rng = rng )
584
- g_fn = function ([], g , mode = jax_mode )
582
+ g_fn = random_function ([], g , mode = jax_mode )
585
583
samples = g_fn ()
586
584
np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
587
585
np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.1 )
@@ -596,7 +594,7 @@ def test_beta_binomial():
596
594
a = np .array ([1.5 , 13 ])
597
595
b = np .array ([0.5 , 9 ])
598
596
g = at .random .betabinom (n , a , b , size = (10_000 , 2 ), rng = rng )
599
- g_fn = function ([], g , mode = jax_mode )
597
+ g_fn = random_function ([], g , mode = jax_mode )
600
598
samples = g_fn ()
601
599
np .testing .assert_allclose (samples .mean (axis = 0 ), n * a / (a + b ), rtol = 0.1 )
602
600
np .testing .assert_allclose (
@@ -614,7 +612,7 @@ def test_multinomial():
614
612
n = np .array ([10 , 40 ])
615
613
p = np .array ([[0.3 , 0.7 , 0.0 ], [0.1 , 0.4 , 0.5 ]])
616
614
g = at .random .multinomial (n , p , size = (10_000 , 2 ), rng = rng )
617
- g_fn = function ([], g , mode = jax_mode )
615
+ g_fn = random_function ([], g , mode = jax_mode )
618
616
samples = g_fn ()
619
617
np .testing .assert_allclose (samples .mean (axis = 0 ), n [..., None ] * p , rtol = 0.1 )
620
618
np .testing .assert_allclose (
@@ -630,7 +628,7 @@ def test_vonmises_mu_outside_circle():
630
628
mu = np .array ([- 30 , 40 ])
631
629
kappa = np .array ([100 , 10 ])
632
630
g = at .random .vonmises (mu , kappa , size = (10_000 , 2 ), rng = rng )
633
- g_fn = function ([], g , mode = jax_mode )
631
+ g_fn = random_function ([], g , mode = jax_mode )
634
632
samples = g_fn ()
635
633
np .testing .assert_allclose (
636
634
samples .mean (axis = 0 ), (mu + np .pi ) % (2.0 * np .pi ) - np .pi , rtol = 0.1
@@ -676,7 +674,10 @@ def rng_fn(cls, rng, size):
676
674
fgraph = FunctionGraph ([out .owner .inputs [0 ]], [out ], clone = False )
677
675
678
676
with pytest .raises (NotImplementedError ):
679
- compare_jax_and_py (fgraph , [])
677
+ with pytest .warns (
678
+ UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
679
+ ):
680
+ compare_jax_and_py (fgraph , [])
680
681
681
682
682
683
def test_random_custom_implementation ():
@@ -707,7 +708,10 @@ def sample_fn(rng, size, dtype, *parameters):
707
708
rng = shared (np .random .RandomState (123 ))
708
709
out = nonexistentrv (rng = rng )
709
710
fgraph = FunctionGraph ([out .owner .inputs [0 ]], [out ], clone = False )
710
- compare_jax_and_py (fgraph , [])
711
+ with pytest .warns (
712
+ UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
713
+ ):
714
+ compare_jax_and_py (fgraph , [])
711
715
712
716
713
717
def test_random_concrete_shape ():
@@ -724,19 +728,15 @@ def test_random_concrete_shape():
724
728
rng = shared (np .random .RandomState (123 ))
725
729
x_at = at .dmatrix ()
726
730
out = at .random .normal (0 , 1 , size = x_at .shape , rng = rng )
727
- jax_fn = function ([x_at ], out , mode = jax_mode )
731
+ jax_fn = random_function ([x_at ], out , mode = jax_mode )
728
732
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
729
733
730
734
731
735
def test_random_concrete_shape_from_param():
    """A RandomVariable sized implicitly by a parameter's runtime shape.

    The output shape is not given explicitly; it must be inferred from the
    matrix parameter when the JAX function is called.
    """
    rng_state = shared(np.random.RandomState(123))
    loc = at.dmatrix()
    draws = at.random.normal(loc, 1, rng=rng_state)
    compiled = random_function([loc], draws, mode=jax_mode)
    assert compiled(np.ones((2, 3))).shape == (2, 3)
741
741
742
742
@@ -755,7 +755,7 @@ def test_random_concrete_shape_subtensor():
755
755
rng = shared (np .random .RandomState (123 ))
756
756
x_at = at .dmatrix ()
757
757
out = at .random .normal (0 , 1 , size = x_at .shape [1 ], rng = rng )
758
- jax_fn = function ([x_at ], out , mode = jax_mode )
758
+ jax_fn = random_function ([x_at ], out , mode = jax_mode )
759
759
assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
760
760
761
761
@@ -771,7 +771,7 @@ def test_random_concrete_shape_subtensor_tuple():
771
771
rng = shared (np .random .RandomState (123 ))
772
772
x_at = at .dmatrix ()
773
773
out = at .random .normal (0 , 1 , size = (x_at .shape [0 ],), rng = rng )
774
- jax_fn = function ([x_at ], out , mode = jax_mode )
774
+ jax_fn = random_function ([x_at ], out , mode = jax_mode )
775
775
assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
776
776
777
777
@@ -782,5 +782,5 @@ def test_random_concrete_shape_graph_input():
782
782
rng = shared (np .random .RandomState (123 ))
783
783
size_at = at .scalar ()
784
784
out = at .random .normal (0 , 1 , size = size_at , rng = rng )
785
- jax_fn = function ([size_at ], out , mode = jax_mode )
785
+ jax_fn = random_function ([size_at ], out , mode = jax_mode )
786
786
assert jax_fn (10 ).shape == (10 ,)
0 commit comments