feat(useless-rewrite): Useless rewrite from log1mexp(log1mexp(x)) to x #535

Closed

lmmx (Contributor) commented Dec 6, 2023

Motivation for these changes

Picking up this ticket on the PyData Global 2023 OSS sprint 🏃

The motivating issue concerned a problem with PyMC's Censored functionality, which was traced back to the need for a "useless rewrite":

  • replacing a $log(1 - exp(x))$ within another $log(1 - exp(x))$
    • i.e. a $log(1 - exp(log(1 - exp(x))))$
  • by just an $x$.
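
As a quick check of why the nested expression collapses, the algebra can be written out as follows. This is a sketch of the identity only; the domain remark (the inner term needs $x \le 0$) is the standard condition for $log(1 - exp(x))$ to be real, not something taken from the rewrite code itself.

```latex
\begin{aligned}
\log\bigl(1 - \exp(\log(1 - \exp(x)))\bigr)
  &= \log\bigl(1 - (1 - \exp(x))\bigr) && \text{requires } 1 - \exp(x) > 0 \text{, i.e. } x < 0 \\
  &= \log(\exp(x)) \\
  &= x
\end{aligned}
```

At $x = 0$ the inner term is $log(0) = -\infty$ and the outer term is $log(1 - 0) = 0$, so the identity also holds in the limit. For $x > 0$ the inner logarithm has a negative argument, so the result is undefined (nan in floating point).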

Implementation details

  1. 🧹 refactor(nan switch): DRY out nan switch rewrite function, easier to follow important parts 🧹
    • The first change prepares for adding a new case by simplifying the case handling (avoiding repetition).
    • This is achieved with an "inner function" (a function defined within a function) that captures x and node from the enclosing function's scope, so they need not be passed as parameters and the body of each case becomes simpler.
    • Every case has a "nan switch", so this trick avoids repetition while keeping the variables in use clear.
  2. ✍️ Add new useless rewrite case ✍️
    • The previous operation and the node operation are both going to be log1mexp for this case
    • The condition for the case is that x >= 0 (confirm?)

Checklist

Major / Breaking Changes

  • N/A

New features

  • "Useless rewrite" to optimise $log(1 - exp(log(1 - exp(x))))$ into just $x$

Bugfixes

Documentation

  • ...

Maintenance

  • Took the opportunity to tidy up the 'nan switch' cases into a dict, which is clearer to read than the repetitive if blocks

lmmx marked this pull request as ready for review December 7, 2023 13:23
lmmx (Contributor, Author) commented Dec 7, 2023

I think to finish this I need to add a test here

@pytest.mark.parametrize("exp_op", [exp, expm1])
def test_exp_log1mexp(self, exp_op):
    # exp(log1mexp(x)) -> switch(x <= 0, 1 - exp(x), nan)
    # expm1(log1mexp(x)) -> switch(x <= 0, - exp(x), nan)
    data_valid = -np.random.random((4, 3)).astype("float32")
    data_valid[0, 0] = 0  # edge case
    data_invalid = data_valid + 1
    x = fmatrix()
    f = function([x], exp_op(log1mexp(x)), mode=self.mode)
    graph = f.maker.fgraph.toposort()
    ops_graph = [
        node
        for node in graph
        if isinstance(node.op, Elemwise)
        and isinstance(
            node.op.scalar_op, (aes.Log, aes.Log1p, aes.Log1mexp, aes.Expm1)
        )
    ]
    assert len(ops_graph) == 0
    if exp_op == exp:
        expected = 1 - np.exp(data_valid)
    else:
        expected = -np.exp(data_valid)
    np.testing.assert_almost_equal(f(data_valid), expected)
    assert np.all(np.isnan(f(data_invalid)))

@@ -319,7 +319,7 @@ def local_exp_log(fgraph, node):
@register_specialize
@node_rewriter([Elemwise])
ricardoV94 (Member) commented Dec 8, 2023

This should be a bit better, as it will only call the rewrite on nodes with these Ops

Suggested change
- @node_rewriter([Elemwise])
+ @node_rewriter([exp, expm1, log1mexp])

@@ -319,7 +319,7 @@ def local_exp_log(fgraph, node):
 @register_specialize
 @node_rewriter([Elemwise])
 def local_exp_log_nan_switch(fgraph, node):
-    # Rewrites of the kind exp(log...(x)) that require a `nan` switch
+    """Rewrites of the kind exp(log...(x)) that require a `nan` switch."""
     x = node.inputs[0]

     if not isinstance(node.op, Elemwise):
Member

Not needed (can't select the return None below for the git suggestion)

Suggested change
- if not isinstance(node.op, Elemwise):

return [new_out]
x = x.owner.inputs[0]

op_map = {
Member

This is neat but I find the old form more readable tbh? Also if I ever have to debug this rewrite I would rather have the unrolled if/elses.

I am happy with the nan_switch_helper and nesting the if/else based on the outer Ops.
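
For illustration, the shape described here (a shared helper defined inside the rewrite, with if/else branches nested by the outer Op) might look roughly like the toy below. It is a scalar stand-in written under stated assumptions, not the PyTensor rewrite: `simplify`, the string Op names, and the float inputs are all hypothetical, and only the helper's name and the three cases (taken from the test comments and the PR description) come from this thread.

```python
import math
from typing import Callable, Optional


def simplify(outer: str, inner: str, x: float) -> Optional[float]:
    """Toy scalar analogue: evaluate outer(inner(x)) via its simplified form.

    Returns None when no simplification applies (hypothetical convention).
    """

    def nan_switch_helper(
        valid: Callable[[float], bool], value: Callable[[float], float]
    ) -> float:
        # `x` is captured from the enclosing scope, so callers only pass
        # the validity condition and the simplified expression.
        return value(x) if valid(x) else math.nan

    if outer == "exp":
        if inner == "log1mexp":
            # exp(log1mexp(x)) -> switch(x <= 0, 1 - exp(x), nan)
            return nan_switch_helper(lambda v: v <= 0, lambda v: 1 - math.exp(v))
    elif outer == "expm1":
        if inner == "log1mexp":
            # expm1(log1mexp(x)) -> switch(x <= 0, -exp(x), nan)
            return nan_switch_helper(lambda v: v <= 0, lambda v: -math.exp(v))
    elif outer == "log1mexp":
        if inner == "log1mexp":
            # log1mexp(log1mexp(x)) -> switch(x <= 0, x, nan)
            return nan_switch_helper(lambda v: v <= 0, lambda v: v)
    return None
```

Each branch stays a plain if/else that can be stepped through in a debugger, while the nan handling lives in one place.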

ricardoV94 (Member)

I think to finish this I need to add a test here

@pytest.mark.parametrize("exp_op", [exp, expm1])
def test_exp_log1mexp(self, exp_op):
    # exp(log1mexp(x)) -> switch(x <= 0, 1 - exp(x), nan)
    # expm1(log1mexp(x)) -> switch(x <= 0, - exp(x), nan)
    data_valid = -np.random.random((4, 3)).astype("float32")
    data_valid[0, 0] = 0  # edge case
    data_invalid = data_valid + 1
    x = fmatrix()
    f = function([x], exp_op(log1mexp(x)), mode=self.mode)
    graph = f.maker.fgraph.toposort()
    ops_graph = [
        node
        for node in graph
        if isinstance(node.op, Elemwise)
        and isinstance(
            node.op.scalar_op, (aes.Log, aes.Log1p, aes.Log1mexp, aes.Expm1)
        )
    ]
    assert len(ops_graph) == 0
    if exp_op == exp:
        expected = 1 - np.exp(data_valid)
    else:
        expected = -np.exp(data_valid)
    np.testing.assert_almost_equal(f(data_valid), expected)
    assert np.all(np.isnan(f(data_invalid)))

Sounds about right

ricardoV94 (Member) commented Dec 8, 2023

The condition for the case is that x >= 0 (confirm?)

The valid condition is x <= 0, for which the inner log1mexp is defined. x > 0 should yield nan, as that would lead to taking the log of a negative number.
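
The domain can be checked numerically with a small scalar sketch. It uses only the standard library; `log1mexp` below is a hand-written stand-in (an assumption, not PyTensor's implementation), and `log1mexp_twice_rewritten` is a hypothetical name for the scalar analogue of `switch(x <= 0, x, nan)`.

```python
import math


def log1mexp(x: float) -> float:
    """Scalar log(1 - exp(x)); nan for x > 0, where 1 - exp(x) is negative."""
    if x > 0:
        return math.nan
    if x == 0:
        return -math.inf  # log(0)
    if x > -math.log(2):
        # Near zero, expm1 avoids cancellation in 1 - exp(x).
        return math.log(-math.expm1(x))
    return math.log1p(-math.exp(x))


def log1mexp_twice_rewritten(x: float) -> float:
    """Scalar analogue of switch(x <= 0, x, nan)."""
    return x if x <= 0 else math.nan


# Compare the nested evaluation with the rewritten form on both sides of 0.
for x in (-3.0, -0.5, -1e-3, 0.0, 0.5):
    print(x, log1mexp(log1mexp(x)), log1mexp_twice_rewritten(x))
```

For non-positive inputs the two columns should agree up to floating-point error, and for the positive input both should be nan.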

@lmmx lmmx closed this by deleting the head repository Feb 10, 2025

Successfully merging this pull request may close these issues.

Optimize log1mexp(log1mexp(x)) -> x
2 participants