-
Notifications
You must be signed in to change notification settings - Fork 135
feat(useless-rewrite): Useless rewrite from log1mexp(log1mexp(x))
to x
#535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8a82ba7
65b2a00
a5d1827
c10c376
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -319,7 +319,7 @@ def local_exp_log(fgraph, node): | |||
@register_specialize | ||||
@node_rewriter([Elemwise]) | ||||
def local_exp_log_nan_switch(fgraph, node): | ||||
# Rewrites of the kind exp(log...(x)) that require a `nan` switch | ||||
"""Rewrites of the kind exp(log...(x)) that require a `nan` switch.""" | ||||
x = node.inputs[0] | ||||
|
||||
if not isinstance(node.op, Elemwise): | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not needed (can't select the return None below for the git suggestion)
Suggested change
|
||||
|
@@ -330,47 +330,48 @@ def local_exp_log_nan_switch(fgraph, node): | |||
prev_op = x.owner.op.scalar_op | ||||
node_op = node.op.scalar_op | ||||
|
||||
# Case for exp(log(x)) -> x | ||||
if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Exp): | ||||
x = x.owner.inputs[0] | ||||
old_out = node.outputs[0] | ||||
new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype)) | ||||
return [new_out] | ||||
|
||||
# Case for exp(log1p(x)) -> x + 1 | ||||
if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Exp): | ||||
x = x.owner.inputs[0] | ||||
old_out = node.outputs[0] | ||||
new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype)) | ||||
return [new_out] | ||||
|
||||
# Case for expm1(log(x)) -> x - 1 | ||||
if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Expm1): | ||||
x = x.owner.inputs[0] | ||||
old_out = node.outputs[0] | ||||
new_out = switch(ge(x, 0), sub(x, 1), np.asarray(np.nan, old_out.dtype)) | ||||
return [new_out] | ||||
def prev_is(op_type): | ||||
return isinstance(prev_op, op_type) | ||||
|
||||
# Case for expm1(log1p(x)) -> x | ||||
if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Expm1): | ||||
x = x.owner.inputs[0] | ||||
old_out = node.outputs[0] | ||||
new_out = switch(ge(x, -1), x, np.asarray(np.nan, old_out.dtype)) | ||||
return [new_out] | ||||
def node_is(op_type): | ||||
return isinstance(node_op, op_type) | ||||
|
||||
# Case for exp(log1mexp(x)) -> 1 - exp(x) | ||||
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Exp): | ||||
x = x.owner.inputs[0] | ||||
def nan_switch(*, if_, substitute) -> list: | ||||
"""Reused inner function because these cases all have the same fallback.""" | ||||
old_out = node.outputs[0] | ||||
new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype)) | ||||
nan_fallback = np.asarray(np.nan, old_out.dtype) | ||||
new_out = switch(if_, substitute, nan_fallback) | ||||
return [new_out] | ||||
|
||||
# Case for expm1(log1mexp(x)) -> -exp(x) | ||||
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Expm1): | ||||
x = x.owner.inputs[0] | ||||
old_out = node.outputs[0] | ||||
new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype)) | ||||
return [new_out] | ||||
x = x.owner.inputs[0] | ||||
|
||||
op_map = { | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is neat but I find the old form more readable tbh? Also if I ever have to debug this rewrite I would rather have the unrolled if/elses. I am happy with the |
||||
aes.Log: { | ||||
# Case for exp(log(x)) -> x | ||||
aes.Exp: (ge(x, 0), x), | ||||
# Case for expm1(log(x)) -> x - 1 | ||||
aes.Expm1: (ge(x, 0), sub(x, 1)), | ||||
}, | ||||
aes.Log1p: { | ||||
# Case for exp(log1p(x)) -> x + 1 | ||||
aes.Exp: (ge(x, -1), add(1, x)), | ||||
# Case for expm1(log1p(x)) -> x | ||||
aes.Expm1: (ge(x, -1), x), | ||||
}, | ||||
aes.Log1mexp: { | ||||
# Case for exp(log1mexp(x)) -> 1 - exp(x) | ||||
aes.Exp: (le(x, 0), sub(1, exp(x))), | ||||
# Case for expm1(log1mexp(x)) -> -exp(x) | ||||
aes.Expm1: (le(x, 0), neg(exp(x))), | ||||
# Case for log1mexp(log1mexp(x)) -> x | ||||
aes.Log1mexp: (ge(x, 0), x), | ||||
}, | ||||
} | ||||
|
||||
for prev_to_match, node_candidates in op_map.items(): | ||||
for node_to_match, (inequality, substitute) in node_candidates.items(): | ||||
if prev_is(prev_to_match) and node_is(node_to_match): | ||||
return nan_switch(if_=inequality, substitute=substitution) | ||||
|
||||
|
||||
@register_canonicalize | ||||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be a bit better, as it will only call the rewrite on nodes with these Ops