Skip to content

Commit 18cd693

Browse files
brandonwillardricardoV94
authored andcommitted
Remove unnecessary graph_inputs usage in OpFromGraph
1 parent 556816f commit 18cd693

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

pytensor/compile/builders.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -344,21 +344,23 @@ def __init__(
344344
if isinstance(i, SharedVariable):
345345
raise TypeError(f"SharedVariables not allowed as inputs; {i}")
346346

347-
for var in graph_inputs(outputs, inputs):
348-
if var not in inputs and not isinstance(var, (Constant, SharedVariable)):
349-
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
350-
351347
if "updates" in kwargs or "givens" in kwargs:
352-
raise NotImplementedError("Updates and givens are not allowed here")
348+
raise NotImplementedError("Updates and givens are not supported")
353349

354350
self.is_inline = inline
355351

356-
# To correctly support shared variables the inner fct should
357-
# not see them. Otherwise there is a problem with the gradient.
358352
self.shared_inputs = []
359-
for var in graph_inputs(outputs):
353+
inner_graph_inputs = graph_inputs(outputs, inputs)
354+
for var in inner_graph_inputs:
360355
if isinstance(var, SharedVariable):
356+
# To correctly support shared variables the inner-graph should
357+
# not see them; otherwise, there will be problems with
358+
# gradients.
359+
# That's why we collect the shared variables and replace them
360+
# with dummies.
361361
self.shared_inputs.append(var)
362+
elif var not in inputs and not isinstance(var, Constant):
363+
raise MissingInputError(f"OpFromGraph is missing an input: {var}")
362364

363365
inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
364366

0 commit comments

Comments
 (0)