Skip to content

feat(useless-rewrite): Useless rewrite from log1mexp(log1mexp(x)) to x #535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 38 additions & 37 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def local_exp_log(fgraph, node):
@register_specialize
@node_rewriter([Elemwise])
Copy link
Member

@ricardoV94 ricardoV94 Dec 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a bit better, as it will only call the rewrite on nodes with these Ops

Suggested change
@node_rewriter([Elemwise])
@node_rewriter([exp, expm1, log1mexp])

def local_exp_log_nan_switch(fgraph, node):
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
"""Rewrites of the kind exp(log...(x)) that require a `nan` switch."""
x = node.inputs[0]

if not isinstance(node.op, Elemwise):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed (can't select the return None below for the git suggestion)

Suggested change
if not isinstance(node.op, Elemwise):

Expand All @@ -330,47 +330,48 @@ def local_exp_log_nan_switch(fgraph, node):
prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op

# Case for exp(log(x)) -> x
if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))
return [new_out]

# Case for exp(log1p(x)) -> x + 1
if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype))
return [new_out]

# Case for expm1(log(x)) -> x - 1
if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, 0), sub(x, 1), np.asarray(np.nan, old_out.dtype))
return [new_out]
def prev_is(op_type):
return isinstance(prev_op, op_type)

# Case for expm1(log1p(x)) -> x
if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, -1), x, np.asarray(np.nan, old_out.dtype))
return [new_out]
def node_is(op_type):
return isinstance(node_op, op_type)

# Case for exp(log1mexp(x)) -> 1 - exp(x)
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
def nan_switch(*, if_, substitute) -> list:
"""Reused inner function because these cases all have the same fallback."""
old_out = node.outputs[0]
new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype))
nan_fallback = np.asarray(np.nan, old_out.dtype)
new_out = switch(if_, substitute, nan_fallback)
return [new_out]

# Case for expm1(log1mexp(x)) -> -exp(x)
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype))
return [new_out]
x = x.owner.inputs[0]

op_map = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is neat but I find the old form more readable tbh? Also if I ever have to debug this rewrite I would rather have the unrolled if/elses.

I am happy with the nan_switch_helper and nesting the if/else based on the outer Ops.

aes.Log: {
# Case for exp(log(x)) -> x
aes.Exp: (ge(x, 0), x),
# Case for expm1(log(x)) -> x - 1
aes.Expm1: (ge(x, 0), sub(x, 1)),
},
aes.Log1p: {
# Case for exp(log1p(x)) -> x + 1
aes.Exp: (ge(x, -1), add(1, x)),
# Case for expm1(log1p(x)) -> x
aes.Expm1: (ge(x, -1), x),
},
aes.Log1mexp: {
# Case for exp(log1mexp(x)) -> 1 - exp(x)
aes.Exp: (le(x, 0), sub(1, exp(x))),
# Case for expm1(log1mexp(x)) -> -exp(x)
aes.Expm1: (le(x, 0), neg(exp(x))),
# Case for log1mexp(log1mexp(x)) -> x
aes.Log1mexp: (ge(x, 0), x),
},
}

for prev_to_match, node_candidates in op_map.items():
for node_to_match, (inequality, substitute) in node_candidates.items():
if prev_is(prev_to_match) and node_is(node_to_match):
return nan_switch(if_=inequality, substitute=substitution)


@register_canonicalize
Expand Down