Moved logsumexp #351

Closed
wants to merge 7 commits into from
26 changes: 0 additions & 26 deletions pytensor/tensor/math.py
@@ -2840,32 +2840,7 @@ def logaddexp(*xs):
return log(add(*[exp(x) for x in xs]))


def logsumexp(x, axis=None, keepdims=False):
"""Compute the log of the sum of exponentials of input elements.

See ``scipy.special.logsumexp``.

Parameters
----------
x : symbolic tensor
Input

axis : None or int or tuple of ints, optional
Axis or axes over which the sum is taken. By default axis is None,
and all elements are summed.

keepdims : bool, optional
If this is set to True, the axes which are reduced are left in the
result as dimensions with size one. With this option, the result will
broadcast correctly against the original array.

Returns
-------
tensor

"""

return log(sum(exp(x), axis=axis, keepdims=keepdims))


class MatMul(Op):
@@ -3130,6 +3105,5 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
"ptp",
"power",
"logaddexp",
"logsumexp",
"hyp2f1",
]
52 changes: 0 additions & 52 deletions pytensor/tensor/rewriting/math.py
@@ -65,9 +65,7 @@
log,
log1mexp,
log1p,
makeKeepDims,
)
from pytensor.tensor.math import max as at_max
from pytensor.tensor.math import maximum, mul, neg
from pytensor.tensor.math import pow as at_pow
from pytensor.tensor.math import (
@@ -2422,56 +2420,6 @@ def local_log_add_exp(fgraph, node):
return [ret]


@register_stabilize
@register_specialize
@node_rewriter([log])
def local_log_sum_exp(fgraph, node):
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))

if node.op != log:
return

sum_node = node.inputs[0].owner
# If the sum has keepdims=True, there might be a dimshuffle
if sum_node and isinstance(sum_node.op, DimShuffle):
dimshuffle_op = sum_node.op
sum_node = sum_node.inputs[0].owner
else:
dimshuffle_op = None

if not sum_node or not isinstance(sum_node.op, Sum):
return

exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis
if not exp_node or not (
isinstance(exp_node.op, Elemwise) and isinstance(exp_node.op.scalar_op, aes.Exp)
):
return

pre_exp = exp_node.inputs[0]
max_pre_exp = at_max(pre_exp, axis=axis)
max_pre_exp_keepdims = makeKeepDims(pre_exp, max_pre_exp, axis)

# Do not offset when max_pre = -np.inf, to avoid nan in the output
# Switch statement is placed directly inside sum to break the self-symmetry
# of the returned output (otherwise the rewrite would not stabilize)
ret = max_pre_exp + log(
at_sum(
switch(
isinf(max_pre_exp_keepdims),
exp(max_pre_exp_keepdims),
exp(pre_exp - max_pre_exp_keepdims),
),
axis=axis,
),
)

# Restore the dimshuffle op, if any.
if dimshuffle_op:
ret = dimshuffle_op(ret)

return [ret]


def add_calculate(num, denum, aslist=False, out_type=None):
# TODO: make sure that this function and mul_calculate are similar
94 changes: 90 additions & 4 deletions pytensor/tensor/special.py
@@ -1,14 +1,99 @@
import warnings
from textwrap import dedent

import numpy as np
import scipy

import pytensor.scalar.basic as aes
from pytensor.graph.basic import Apply
from pytensor.link.c.op import COp
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import gamma, neg, sum
from pytensor.tensor.basic import as_tensor_variable, switch
from pytensor.tensor.math import gamma, neg, sum, Sum, makeKeepDims, isinf, log, exp
from pytensor.tensor.math import max as at_max
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
)
from pytensor.graph.rewriting.basic import node_rewriter


def logsumexp(x, axis=None, keepdims=False):
"""Compute the log of the sum of exponentials of input elements.

See ``scipy.special.logsumexp``.

Parameters
----------
x : symbolic tensor
Input

axis : None or int or tuple of ints, optional
Axis or axes over which the sum is taken. By default axis is None,
and all elements are summed.

keepdims : bool, optional
If this is set to True, the axes which are reduced are left in the
result as dimensions with size one. With this option, the result will
broadcast correctly against the original array.

Returns
-------
tensor

"""

return log(sum(exp(x), axis=axis, keepdims=keepdims))

@register_stabilize
@register_specialize
@node_rewriter([log])
def local_log_sum_exp(fgraph, node):
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))

if node.op != log:
return

sum_node = node.inputs[0].owner
# If the sum has keepdims=True, there might be a dimshuffle
if sum_node and isinstance(sum_node.op, DimShuffle):
dimshuffle_op = sum_node.op
sum_node = sum_node.inputs[0].owner
else:
dimshuffle_op = None

if not sum_node or not isinstance(sum_node.op, Sum):
return

exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis
if not exp_node or not (
isinstance(exp_node.op, Elemwise) and isinstance(exp_node.op.scalar_op, aes.Exp)
):
return

pre_exp = exp_node.inputs[0]
max_pre_exp = at_max(pre_exp, axis=axis)
max_pre_exp_keepdims = makeKeepDims(pre_exp, max_pre_exp, axis)

# Do not offset when max_pre = -np.inf, to avoid nan in the output
# Switch statement is placed directly inside sum to break the self-symmetry
# of the returned output (otherwise the rewrite would not stabilize)
ret = max_pre_exp + log(
at_sum(
switch(
isinf(max_pre_exp_keepdims),
exp(max_pre_exp_keepdims),
exp(pre_exp - max_pre_exp_keepdims),
),
axis=axis,
),
)

# Restore the dimshuffle op, if any.
if dimshuffle_op:
ret = dimshuffle_op(ret)

return [ret]

class SoftmaxGrad(COp):
"""
@@ -789,4 +874,5 @@ def factorial(n):
"log_softmax",
"poch",
"factorial",
"logsumexp",
]
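For reference, here is a minimal NumPy sketch of the identity quoted in the comment of the moved local_log_sum_exp rewrite (illustrative only, not part of the diff; scipy_logsumexp is just a renamed import of scipy.special.logsumexp). The naive log(sum(exp(x))) overflows float64 once any entry exceeds roughly 709, while the max-shifted form stays finite.

import numpy as np
from scipy.special import logsumexp as scipy_logsumexp

x = np.array([-1000.0, 1000.0])

# Naive evaluation: np.exp(1000.0) overflows float64 to inf, so the result is inf.
naive = np.log(np.sum(np.exp(x)))

# Max-shifted evaluation, as in the rewrite:
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
x_max = np.max(x)
stable = x_max + np.log(np.sum(np.exp(x - x_max)))

print(naive)               # inf (with an overflow RuntimeWarning)
print(stable)              # 1000.0
print(scipy_logsumexp(x))  # 1000.0, the reference implementation named in the docstring

The sketch does not cover the all minus-infinity input case; that is what the switch on isinf(max_pre_exp_keepdims) inside the rewrite handles.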
91 changes: 1 addition & 90 deletions tests/tensor/rewriting/test_math.py
@@ -34,7 +34,7 @@
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
from pytensor.tensor.math import Dot, Prod, Sum, _conj
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import add
from pytensor.tensor.math import all as at_all
@@ -3481,95 +3481,6 @@ def test_local_expm1():
)


def compile_graph_log_sum_exp(x, axis, dimshuffle_op=None):
sum_exp = at_sum(exp(x), axis=axis)
if dimshuffle_op:
sum_exp = dimshuffle_op(sum_exp)
y = log(sum_exp)
MODE = get_default_mode().including("local_log_sum_exp")
return function([x], y, mode=MODE)


def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
f = compile_graph_log_sum_exp(x, axis, dimshuffle_op)

fgraph = f.maker.fgraph.toposort()
for node in fgraph:
if (
hasattr(node.op, "scalar_op")
and node.op.scalar_op == aes.basic.scalar_maximum
):
return

# In mode FAST_COMPILE, the rewrites don't replace the
# `MaxAndArgmax` `Op`.
if isinstance(node.op, MaxAndArgmax):
return

# TODO FIXME: Refactor this test so that it makes a direct assertion and
# nothing more.
raise AssertionError("No maximum detected after log_sum_exp rewrite")


def test_local_log_sum_exp_maximum():
"""Test that the rewrite is applied by checking the presence of the maximum."""
x = tensor3("x")
check_max_log_sum_exp(x, axis=(0,), dimshuffle_op=None)
check_max_log_sum_exp(x, axis=(1,), dimshuffle_op=None)
check_max_log_sum_exp(x, axis=(2,), dimshuffle_op=None)
check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=None)
check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)

# If a transpose is applied to the sum
transpose_op = DimShuffle((False, False), (1, 0))
check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)

# If the sum is performed with keepdims=True
x = TensorType(dtype="floatX", shape=(None, 1, None))("x")
sum_keepdims_op = x.sum(axis=(0, 1), keepdims=True).owner.op
check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=sum_keepdims_op)


def test_local_log_sum_exp_near_one():
"""Test that the rewritten result is correct around 1.0."""

x = tensor3("x")
x_val = 1.0 + np.random.random((4, 3, 2)).astype(config.floatX) / 10.0

f = compile_graph_log_sum_exp(x, axis=(1,))
naive_ret = np.log(np.sum(np.exp(x_val), axis=1))
rewritten_ret = f(x_val)
assert np.allclose(naive_ret, rewritten_ret)

# If a transpose is applied
transpose_op = DimShuffle((False, False), (1, 0))
f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
naive_ret = np.log(np.sum(np.exp(x_val), axis=1).T)
rewritten_ret = f(x_val)
assert np.allclose(naive_ret, rewritten_ret)


def test_local_log_sum_exp_large():
"""Test that the rewrite result is correct for extreme value 100."""
x = vector("x")
f = compile_graph_log_sum_exp(x, axis=0)

x_val = np.array([-100.0, 100.0]).astype(config.floatX)

rewritten_ret = f(x_val)
assert np.allclose(rewritten_ret, 100.0)


def test_local_log_sum_exp_inf():
"""Test that when max = +-inf, the rewritten output still works correctly."""
x = vector("x")
f = compile_graph_log_sum_exp(x, axis=0)

assert f([-np.inf, -np.inf]) == -np.inf
assert f([np.inf, np.inf]) == np.inf
assert f([-np.inf, np.inf]) == np.inf


def test_local_reciprocal_1_plus_exp():
x = vector("x")
y = at.reciprocal(1 + exp(x))
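Finally, a short usage sketch of logsumexp imported from its new location in pytensor.tensor.special, in the spirit of the moved test_local_log_sum_exp_large (illustrative only, not part of the diff; it assumes the default compilation mode applies the local_log_sum_exp rewrite registered above via register_stabilize and register_specialize).

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.special import logsumexp

x = pt.vector("x", dtype="float64")
y = logsumexp(x, axis=0)

# During compilation, local_log_sum_exp rewrites log(sum(exp(x))) into the
# max-shifted form, so extreme inputs do not overflow.
f = pytensor.function([x], y)

print(f(np.array([-1000.0, 1000.0])))  # ~1000.0; the unrewritten graph would return inf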