
Commit ed04105

Differentiate between None and empty shape
1 parent 909ad93 commit ed04105

File tree

7 files changed: +126 -114 lines changed

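For context, the distinction being introduced mirrors NumPy's own behaviour: `size=None` lets the draw shape follow the broadcast parameter shape, while an explicit (even empty) `size` is enforced as the output shape (see the NOTE removed from `explicit_expand_dims` below and https://github.com/pymc-devs/pytensor/issues/568). A minimal NumPy-only illustration of the difference:

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.full((3, 2), 0.5)

# size=None: the draw shape is inferred from the broadcast parameter shape.
print(rng.binomial(1, p, size=None).shape)  # (3, 2)

# size=(): an explicit empty shape requests a scalar draw, which cannot hold
# the (3, 2) parameter shape, so NumPy raises instead of inferring.
try:
    rng.binomial(1, p, size=())
except ValueError as err:
    print("ValueError:", err)
```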

pytensor/link/numba/dispatch/random.py

Lines changed: 20 additions & 12 deletions
@@ -18,6 +18,7 @@
     get_name_for_object,
     unique_name_generator,
 )
+from pytensor.tensor import NoneConst
 from pytensor.tensor.basic import get_vector_length
 from pytensor.tensor.random.type import RandomStateType
 
@@ -98,8 +99,7 @@ def make_numba_random_fn(node, np_random_func):
     if not isinstance(node.inputs[0].type, RandomStateType):
         raise TypeError("Numba does not support NumPy `Generator`s")
 
-    tuple_size = int(get_vector_length(node.inputs[1]))
-    size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
+    size = node.inputs[1]
 
     # Make a broadcast-capable version of the Numba supported scalar sampling
     # function
@@ -115,8 +115,6 @@ def make_numba_random_fn(node, np_random_func):
         "np_random_func",
         "numba_vectorize",
         "to_fixed_tuple",
-        "tuple_size",
-        "size_dims",
         "rng",
         "size",
         "dtype",
@@ -152,7 +150,10 @@ def {bcast_fn_name}({bcast_fn_input_names}):
         "out_dtype": out_dtype,
     }
 
-    if tuple_size > 0:
+    if not NoneConst.equals(size):
+        tuple_size = int(get_vector_length(node.inputs[1]))
+        size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
+
         random_fn_body = dedent(
             f"""
             size = to_fixed_tuple(size, tuple_size)
@@ -302,12 +303,15 @@ def body_fn(a):
 @numba_funcify.register(ptr.CategoricalRV)
 def numba_funcify_CategoricalRV(op, node, **kwargs):
     out_dtype = node.outputs[1].type.numpy_dtype
-    size_len = int(get_vector_length(node.inputs[1]))
+    size = node.inputs[1]
+    none_size = NoneConst.equals(size)
+    if not none_size:
+        size_len = int(get_vector_length(size))
     p_ndim = node.inputs[-1].ndim
 
     @numba_basic.numba_njit
     def categorical_rv(rng, size, dtype, p):
-        if not size_len:
+        if none_size:
             size_tpl = p.shape[:-1]
         else:
             size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
@@ -333,22 +337,25 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
     out_dtype = node.outputs[1].type.numpy_dtype
     alphas_ndim = node.inputs[3].type.ndim
     neg_ind_shape_len = -alphas_ndim + 1
-    size_len = int(get_vector_length(node.inputs[1]))
+    size = node.inputs[1]
+    none_size = NoneConst.equals(size)
+    if not none_size:
+        size_len = int(get_vector_length(size))
 
     if alphas_ndim > 1:
 
         @numba_basic.numba_njit
         def dirichlet_rv(rng, size, dtype, alphas):
-            if size_len > 0:
+            if none_size:
+                samples_shape = alphas.shape
+            else:
                 size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
                 if (
                     0 < alphas.ndim - 1 <= len(size_tpl)
                     and size_tpl[neg_ind_shape_len:] != alphas.shape[:-1]
                 ):
                     raise ValueError("Parameters shape and size do not match.")
                 samples_shape = size_tpl + alphas.shape[-1:]
-            else:
-                samples_shape = alphas.shape
 
             res = np.empty(samples_shape, dtype=out_dtype)
             alphas_bcast = np.broadcast_to(alphas, samples_shape)
@@ -362,7 +369,8 @@ def dirichlet_rv(rng, size, dtype, alphas):
 
         @numba_basic.numba_njit
         def dirichlet_rv(rng, size, dtype, alphas):
-            size = numba_ndarray.to_fixed_tuple(size, size_len)
+            if size is not None:
+                size = numba_ndarray.to_fixed_tuple(size, size_len)
             return (rng, np.random.dirichlet(alphas, size))
 
     return dirichlet_rv
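The Numba dispatchers above now branch on whether the symbolic `size` input is the `NoneConst` constant rather than on a zero-length size vector. A small sketch of that check, using only the helpers imported in the diff (`NoneConst`, `get_vector_length`); the printed values are what one would expect, not output copied from a run:

```python
import pytensor.tensor as pt
from pytensor.tensor import NoneConst

empty_size = pt.constant([], dtype="int64")   # the old encoding of "no size"

print(NoneConst.equals(empty_size))      # False: an empty int64 vector is not the None constant
print(NoneConst.equals(NoneConst))       # True
print(pt.get_vector_length(empty_size))  # 0: the ambiguity the old `size_len` checks relied on
```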

pytensor/tensor/random/op.py

Lines changed: 8 additions & 24 deletions
@@ -15,7 +15,6 @@
     as_tensor_variable,
     concatenate,
     constant,
-    get_underlying_scalar_constant_value,
     get_vector_length,
     infer_static_shape,
 )
@@ -133,7 +132,7 @@ def __str__(self):
 
     def _infer_shape(
         self,
-        size: TensorVariable,
+        size: Union[TensorVariable, NoneConst],
         dist_params: Sequence[TensorVariable],
         param_shapes: Optional[Sequence[tuple[Variable, ...]]] = None,
     ) -> Union[TensorVariable, tuple[ScalarVariable, ...]]:
@@ -162,9 +161,9 @@ def _infer_shape(
             self._supp_shape_from_params(dist_params, param_shapes=param_shapes)
         )
 
-        size_len = get_vector_length(size)
+        if not NoneConst.equals(size):
+            size_len = get_vector_length(size)
 
-        if size_len > 0:
             # Fail early when size is incompatible with parameters
             for i, (param, param_ndim_supp) in enumerate(
                 zip(dist_params, self.ndims_params)
@@ -174,7 +173,7 @@ def _infer_shape(
                     raise ValueError(
                         f"Size length is incompatible with batched dimensions of parameter {i} {param}:\n"
                         f"len(size) = {size_len}, len(batched dims {param}) = {param_batched_dims}. "
-                        f"Size length must be 0 or >= {param_batched_dims}"
+                        f"Size must be None or have length >= {param_batched_dims}"
                     )
 
             return tuple(size) + supp_shape
@@ -218,22 +217,12 @@ def extract_batch_shape(p, ps, n):
 
         shape = batch_shape + supp_shape
 
-        if not shape:
-            shape = constant([], dtype="int64")
-
         return shape
 
     def infer_shape(self, fgraph, node, input_shapes):
         _, size, _, *dist_params = node.inputs
         _, size_shape, _, *param_shapes = input_shapes
 
-        try:
-            size_len = get_vector_length(size)
-        except ValueError:
-            size_len = get_underlying_scalar_constant_value(size_shape[0])
-
-        size = tuple(size[n] for n in range(size_len))
-
         shape = self._infer_shape(size, dist_params, param_shapes=param_shapes)
 
         return [None, list(shape)]
@@ -313,12 +302,7 @@ def perform(self, node, inputs, outputs):
 
         out_var = node.outputs[1]
 
-        # If `size == []`, that means no size is enforced, and NumPy is trusted
-        # to draw the appropriate number of samples, NumPy uses `size=None` to
-        # represent that. Otherwise, NumPy expects a tuple.
-        if np.size(size) == 0:
-            size = None
-        else:
+        if size is not None:
             size = tuple(size)
 
         # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng`
@@ -394,21 +378,21 @@ def vectorize_random_variable(
     # Need to make parameters implicit broadcasting explicit
     original_dist_params = node.inputs[3:]
     old_size = node.inputs[1]
-    len_old_size = get_vector_length(old_size)
 
     original_expanded_dist_params = explicit_expand_dims(
-        original_dist_params, op.ndims_params, len_old_size
+        original_dist_params, op.ndims_params, old_size
     )
     # We call vectorize_graph to automatically handle any new explicit expand_dims
     dist_params = vectorize_graph(
         original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
     )
 
-    if len_old_size and equal_computations([old_size], [size]):
+    if (not NoneConst.equals(size)) and equal_computations([old_size], [size]):
         # If the original RV had a size variable and a new one has not been provided,
         # we need to define a new size as the concatenation of the original size dimensions
         # and the novel ones implied by new broadcasted batched parameters dimensions.
         # We use the first broadcasted batch dimension for reference.
+        len_old_size = get_vector_length(old_size)
        bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0]
         new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
         if new_param_ndim >= 0:
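With `size` carried through as `NoneConst`, `perform` above hands `size=None` directly to NumPy, so the draw shape falls back to the broadcast parameter shape. A hedged usage sketch of the two cases (assumes this commit's behaviour; the shapes shown are the expected ones, not captured output):

```python
import numpy as np
from pytensor.tensor.random.basic import normal

# No size: the output shape comes from the (3, 2) `loc` parameter.
x = normal(np.zeros((3, 2)), 1.0, size=None)
print(x.eval().shape)  # (3, 2)

# Explicit size: extra leading batch dimensions are drawn on top of the parameter shape.
y = normal(np.zeros((3, 2)), 1.0, size=(4, 3, 2))
print(y.eval().shape)  # (4, 3, 2)
```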

pytensor/tensor/random/rewriting/basic.py

Lines changed: 11 additions & 14 deletions
@@ -7,7 +7,7 @@
 from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
 from pytensor.scalar import integer_types
 from pytensor.tensor import NoneConst
-from pytensor.tensor.basic import constant, get_vector_length
+from pytensor.tensor.basic import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.random.op import RandomVariable
@@ -85,7 +85,7 @@ def local_rv_size_lift(fgraph, node):
 
     dist_params = broadcast_params(dist_params, node.op.ndims_params)
 
-    if get_vector_length(size) > 0:
+    if not NoneConst.equals(size):
         dist_params = [
             broadcast_to(
                 p,
@@ -156,20 +156,17 @@ def local_dimshuffle_rv_lift(fgraph, node):
     if is_rv_used_in_graph(base_rv, node, fgraph):
         return False
 
-    batched_dims = rv.ndim - rv_op.ndim_supp
+    batched_dims = rv.type.ndim - rv_op.ndim_supp
     batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)
 
-    # Make size explicit
-    missing_size_dims = batched_dims - get_vector_length(size)
-    if missing_size_dims > 0:
-        full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
-        size = full_size[:missing_size_dims] + tuple(size)
-
-    # Update the size to reflect the DimShuffled dimensions
-    new_size = [
-        constant(1, dtype="int64") if o == "x" else size[o]
-        for o in batched_dims_ds_order
-    ]
+    if NoneConst.equals(size):
+        new_size = NoneConst
+    else:
+        # Update the size to reflect the DimShuffled dimensions
+        new_size = [
+            constant(1, dtype="int64") if o == "x" else size[o]
+            for o in batched_dims_ds_order
+        ]
 
     # Updates the params to reflect the Dimshuffled dimensions
     new_dist_params = []
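`local_dimshuffle_rv_lift` now keeps `size` as `NoneConst` when no size was given, instead of first materializing an explicit size from the broadcast parameter shapes. A sketch of the kind of graph the rewrite targets, a DimShuffle sitting on top of an RV built with `size=None`; the rewrite itself only fires during PyTensor's rewriting passes (e.g. inside `pytensor.function`):

```python
import numpy as np
from pytensor.tensor.random.basic import normal

rv = normal(np.zeros((3, 2)), 1.0, size=None)
transposed = rv.T               # DimShuffle on top of the RandomVariable
print(transposed.eval().shape)  # (2, 3)
```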

pytensor/tensor/random/utils.py

Lines changed: 12 additions & 13 deletions
@@ -7,10 +7,10 @@
 import numpy as np
 
 from pytensor.compile.sharedvalue import shared
-from pytensor.graph.basic import Constant, Variable
+from pytensor.graph.basic import Variable
 from pytensor.scalar import ScalarVariable
-from pytensor.tensor import get_vector_length
-from pytensor.tensor.basic import as_tensor_variable, cast, constant
+from pytensor.tensor import NoneConst, get_vector_length
+from pytensor.tensor.basic import as_tensor_variable, cast
 from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import maximum
 from pytensor.tensor.shape import shape_padleft, specify_shape
@@ -124,20 +124,18 @@ def broadcast_params(params, ndims_params):
 def explicit_expand_dims(
     params: Sequence[TensorVariable],
     ndim_params: tuple[int],
-    size_length: int = 0,
+    size: Union[TensorVariable, NoneConst] = NoneConst,
 ) -> list[TensorVariable]:
     """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
 
     batch_dims = [
         param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params)
     ]
 
-    if size_length:
-        # NOTE: PyTensor is currently treating zero-length size as size=None, which is not what Numpy does
-        # See: https://github.com/pymc-devs/pytensor/issues/568
-        max_batch_dims = size_length
-    else:
+    if NoneConst.equals(size):
         max_batch_dims = max(batch_dims)
+    else:
+        max_batch_dims = get_vector_length(size)
 
     new_params = []
     for new_param, batch_dim in zip(params, batch_dims):
@@ -153,9 +151,10 @@ def normalize_size_param(
     size: Optional[Union[int, np.ndarray, Variable, Sequence]],
 ) -> Variable:
     """Create an PyTensor value for a ``RandomVariable`` ``size`` parameter."""
-    if size is None:
-        size = constant([], dtype="int64")
-    elif isinstance(size, int):
+    if size is None or NoneConst.equals(size):
+        return NoneConst
+
+    if isinstance(size, int):
         size = as_tensor_variable([size], ndim=1)
     elif not isinstance(size, (np.ndarray, Variable, Sequence)):
         raise TypeError(
@@ -164,7 +163,7 @@ def normalize_size_param(
     else:
         size = cast(as_tensor_variable(size, ndim=1, dtype="int64"), "int64")
 
-    if not isinstance(size, Constant):
+    if size.type.shape == (None,):
         # This should help ensure that the length of non-constant `size`s
         # will be available after certain types of cloning (e.g. the kind
         # `Scan` performs)
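Taken together, `normalize_size_param` now returns the `NoneConst` constant for `size=None` rather than an empty int64 vector. A small sketch of the expected round trips (names taken from the diff; the printed values are an expectation, not captured output):

```python
from pytensor.tensor import NoneConst
from pytensor.tensor.random.utils import normalize_size_param

print(NoneConst.equals(normalize_size_param(None)))  # True: None is preserved symbolically
print(normalize_size_param(5).eval())                # [5]
print(normalize_size_param((2, 3)).eval())           # [2 3]
```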
