diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index a814ffdf69..ccb48abf27 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -319,7 +319,7 @@ def local_exp_log(fgraph, node): @register_specialize @node_rewriter([Elemwise]) def local_exp_log_nan_switch(fgraph, node): - # Rewrites of the kind exp(log...(x)) that require a `nan` switch + """Rewrites of the kind exp(log...(x)) that require a `nan` switch.""" x = node.inputs[0] if not isinstance(node.op, Elemwise): @@ -330,47 +330,48 @@ def local_exp_log_nan_switch(fgraph, node): prev_op = x.owner.op.scalar_op node_op = node.op.scalar_op - # Case for exp(log(x)) -> x - if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Exp): - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype)) - return [new_out] - - # Case for exp(log1p(x)) -> x + 1 - if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Exp): - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype)) - return [new_out] - - # Case for expm1(log(x)) -> x - 1 - if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Expm1): - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = switch(ge(x, 0), sub(x, 1), np.asarray(np.nan, old_out.dtype)) - return [new_out] + def prev_is(op_type): + return isinstance(prev_op, op_type) - # Case for expm1(log1p(x)) -> x - if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Expm1): - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = switch(ge(x, -1), x, np.asarray(np.nan, old_out.dtype)) - return [new_out] + def node_is(op_type): + return isinstance(node_op, op_type) - # Case for exp(log1mexp(x)) -> 1 - exp(x) - if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Exp): - x = x.owner.inputs[0] + def nan_switch(*, if_, substitute) -> list: + """Reused inner function because these cases all have the same fallback.""" old_out = node.outputs[0] - new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype)) + nan_fallback = np.asarray(np.nan, old_out.dtype) + new_out = switch(if_, substitute, nan_fallback) return [new_out] - # Case for expm1(log1mexp(x)) -> -exp(x) - if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Expm1): - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype)) - return [new_out] + x = x.owner.inputs[0] + + op_map = { + aes.Log: { + # Case for exp(log(x)) -> x + aes.Exp: (ge(x, 0), x), + # Case for expm1(log(x)) -> x - 1 + aes.Expm1: (ge(x, 0), sub(x, 1)), + }, + aes.Log1p: { + # Case for exp(log1p(x)) -> x + 1 + aes.Exp: (ge(x, -1), add(1, x)), + # Case for expm1(log1p(x)) -> x + aes.Expm1: (ge(x, -1), x), + }, + aes.Log1mexp: { + # Case for exp(log1mexp(x)) -> 1 - exp(x) + aes.Exp: (le(x, 0), sub(1, exp(x))), + # Case for expm1(log1mexp(x)) -> -exp(x) + aes.Expm1: (le(x, 0), neg(exp(x))), + # Case for log1mexp(log1mexp(x)) -> x + aes.Log1mexp: (ge(x, 0), x), + }, + } + + for prev_to_match, node_candidates in op_map.items(): + for node_to_match, (inequality, substitute) in node_candidates.items(): + if prev_is(prev_to_match) and node_is(node_to_match): + return nan_switch(if_=inequality, substitute=substitution) @register_canonicalize