diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 358c3f3724..d3a2ded6cb 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -2840,32 +2840,7 @@ def logaddexp(*xs):
     return log(add(*[exp(x) for x in xs]))


-def logsumexp(x, axis=None, keepdims=False):
-    """Compute the log of the sum of exponentials of input elements.
-    See ``scipy.special.logsumexp``.
-
-    Parameters
-    ----------
-    x : symbolic tensor
-        Input
-
-    axis : None or int or tuple of ints, optional
-        Axis or axes over which the sum is taken. By default axis is None,
-        and all elements are summed.
-
-    keepdims : bool, optional
-        If this is set to True, the axes which are reduced are left in the
-        result as dimensions with size one. With this option, the result will
-        broadcast correctly against the original array.
-
-    Returns
-    -------
-    tensor
-
-    """
-
-    return log(sum(exp(x), axis=axis, keepdims=keepdims))


 class MatMul(Op):
@@ -3130,6 +3105,5 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     "ptp",
     "power",
     "logaddexp",
-    "logsumexp",
     "hyp2f1",
 ]
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index b0d124f1c8..20d06e9e6d 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -65,9 +65,7 @@
     log,
     log1mexp,
     log1p,
-    makeKeepDims,
 )
-from pytensor.tensor.math import max as at_max
 from pytensor.tensor.math import maximum, mul, neg
 from pytensor.tensor.math import pow as at_pow
 from pytensor.tensor.math import (
@@ -2422,56 +2420,6 @@ def local_log_add_exp(fgraph, node):
         return [ret]


-@register_stabilize
-@register_specialize
-@node_rewriter([log])
-def local_log_sum_exp(fgraph, node):
-    # log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
-
-    if node.op != log:
-        return
-
-    sum_node = node.inputs[0].owner
-    # If the sum has keepdims=True, there might be a dimshuffle
-    if sum_node and isinstance(sum_node.op, DimShuffle):
-        dimshuffle_op = sum_node.op
-        sum_node = sum_node.inputs[0].owner
-    else:
-        dimshuffle_op = None
-
-    if not sum_node or not isinstance(sum_node.op, Sum):
-        return
-
-    exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis
-    if not exp_node or not (
-        isinstance(exp_node.op, Elemwise) and isinstance(exp_node.op.scalar_op, aes.Exp)
-    ):
-        return
-
-    pre_exp = exp_node.inputs[0]
-    max_pre_exp = at_max(pre_exp, axis=axis)
-    max_pre_exp_keepdims = makeKeepDims(pre_exp, max_pre_exp, axis)
-
-    # Do not offset when max_pre = -np.inf, to avoid nan in the output
-    # Switch statement is placed directly inside sum to break the self-symmetry
-    # of the returned output (otherwise the rewrite would not stabilize)
-    ret = max_pre_exp + log(
-        at_sum(
-            switch(
-                isinf(max_pre_exp_keepdims),
-                exp(max_pre_exp_keepdims),
-                exp(pre_exp - max_pre_exp_keepdims),
-            ),
-            axis=axis,
-        ),
-    )
-
-    # Restore the dimshuffle op, if any.
-    if dimshuffle_op:
-        ret = dimshuffle_op(ret)
-
-    return [ret]
-
-
 def add_calculate(num, denum, aslist=False, out_type=None):
     # TODO: make sure that this function and mul_calculate are similar
diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py
index 1342c281cd..bfd1695188 100644
--- a/pytensor/tensor/special.py
+++ b/pytensor/tensor/special.py
@@ -1,14 +1,99 @@
 import warnings
 from textwrap import dedent
-
 import numpy as np
 import scipy
-
+import pytensor.scalar.basic as aes
 from pytensor.graph.basic import Apply
 from pytensor.link.c.op import COp
-from pytensor.tensor.basic import as_tensor_variable
-from pytensor.tensor.math import gamma, neg, sum
+from pytensor.tensor.basic import as_tensor_variable, switch
+from pytensor.tensor.math import gamma, neg, sum, Sum, makeKeepDims, isinf, log, exp
+from pytensor.tensor.math import max as at_max
+from pytensor.tensor.math import sum as at_sum
+from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.rewriting.basic import (
+    register_specialize,
+    register_stabilize,
+
+)
+from pytensor.graph.rewriting.basic import node_rewriter
+
+
+def logsumexp(x, axis=None, keepdims=False):
+    """Compute the log of the sum of exponentials of input elements.
+
+    See ``scipy.special.logsumexp``.
+
+    Parameters
+    ----------
+    x : symbolic tensor
+        Input
+
+    axis : None or int or tuple of ints, optional
+        Axis or axes over which the sum is taken. By default axis is None,
+        and all elements are summed.
+
+    keepdims : bool, optional
+        If this is set to True, the axes which are reduced are left in the
+        result as dimensions with size one. With this option, the result will
+        broadcast correctly against the original array.
+
+    Returns
+    -------
+    tensor
+
+    """
+    return log(sum(exp(x), axis=axis, keepdims=keepdims))
+
+@register_stabilize
+@register_specialize
+@node_rewriter([log])
+def local_log_sum_exp(fgraph, node):
+    # log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
+
+    if node.op != log:
+        return
+
+    sum_node = node.inputs[0].owner
+    # If the sum has keepdims=True, there might be a dimshuffle
+    if sum_node and isinstance(sum_node.op, DimShuffle):
+        dimshuffle_op = sum_node.op
+        sum_node = sum_node.inputs[0].owner
+    else:
+        dimshuffle_op = None
+
+    if not sum_node or not isinstance(sum_node.op, Sum):
+        return
+
+    exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis
+    if not exp_node or not (
+        isinstance(exp_node.op, Elemwise) and isinstance(exp_node.op.scalar_op, aes.Exp)
+    ):
+        return
+
+    pre_exp = exp_node.inputs[0]
+    max_pre_exp = at_max(pre_exp, axis=axis)
+    max_pre_exp_keepdims = makeKeepDims(pre_exp, max_pre_exp, axis)
+
+    # Do not offset when max_pre = -np.inf, to avoid nan in the output
+    # Switch statement is placed directly inside sum to break the self-symmetry
+    # of the returned output (otherwise the rewrite would not stabilize)
+    ret = max_pre_exp + log(
+        at_sum(
+            switch(
+                isinf(max_pre_exp_keepdims),
+                exp(max_pre_exp_keepdims),
+                exp(pre_exp - max_pre_exp_keepdims),
+            ),
+            axis=axis,
+        ),
+    )
+
+    # Restore the dimshuffle op, if any.
+    if dimshuffle_op:
+        ret = dimshuffle_op(ret)
+
+    return [ret]


 class SoftmaxGrad(COp):
     """
@@ -789,4 +874,5 @@ def factorial(n):
     "log_softmax",
     "poch",
     "factorial",
+    "logsumexp",
 ]
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index f69879a51d..53cbca0229 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -34,7 +34,7 @@
 from pytensor.tensor.blas import Dot22, Gemv
 from pytensor.tensor.blas_c import CGemv
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
+from pytensor.tensor.math import Dot, Prod, Sum, _conj
 from pytensor.tensor.math import abs as at_abs
 from pytensor.tensor.math import add
 from pytensor.tensor.math import all as at_all
@@ -3481,95 +3481,6 @@ def test_local_expm1():
     )


-def compile_graph_log_sum_exp(x, axis, dimshuffle_op=None):
-    sum_exp = at_sum(exp(x), axis=axis)
-    if dimshuffle_op:
-        sum_exp = dimshuffle_op(sum_exp)
-    y = log(sum_exp)
-    MODE = get_default_mode().including("local_log_sum_exp")
-    return function([x], y, mode=MODE)
-
-
-def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
-    f = compile_graph_log_sum_exp(x, axis, dimshuffle_op)
-
-    fgraph = f.maker.fgraph.toposort()
-    for node in fgraph:
-        if (
-            hasattr(node.op, "scalar_op")
-            and node.op.scalar_op == aes.basic.scalar_maximum
-        ):
-            return
-
-        # In mode FAST_COMPILE, the rewrites don't replace the
-        # `MaxAndArgmax` `Op`.
-        if isinstance(node.op, MaxAndArgmax):
-            return
-
-    # TODO FIXME: Refactor this test so that it makes a direct assertion and
-    # nothing more.
-    raise AssertionError("No maximum detected after log_sum_exp rewrite")
-
-
-def test_local_log_sum_exp_maximum():
-    """Test that the rewrite is applied by checking the presence of the maximum."""
-    x = tensor3("x")
-    check_max_log_sum_exp(x, axis=(0,), dimshuffle_op=None)
-    check_max_log_sum_exp(x, axis=(1,), dimshuffle_op=None)
-    check_max_log_sum_exp(x, axis=(2,), dimshuffle_op=None)
-    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=None)
-    check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)
-
-    # If a transpose is applied to the sum
-    transpose_op = DimShuffle((False, False), (1, 0))
-    check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)
-
-    # If the sum is performed with keepdims=True
-    x = TensorType(dtype="floatX", shape=(None, 1, None))("x")
-    sum_keepdims_op = x.sum(axis=(0, 1), keepdims=True).owner.op
-    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=sum_keepdims_op)
-
-
-def test_local_log_sum_exp_near_one():
-    """Test that the rewritten result is correct around 1.0."""
-
-    x = tensor3("x")
-    x_val = 1.0 + np.random.random((4, 3, 2)).astype(config.floatX) / 10.0
-
-    f = compile_graph_log_sum_exp(x, axis=(1,))
-    naive_ret = np.log(np.sum(np.exp(x_val), axis=1))
-    rewritten_ret = f(x_val)
-    assert np.allclose(naive_ret, rewritten_ret)
-
-    # If a transpose is applied
-    transpose_op = DimShuffle((False, False), (1, 0))
-    f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
-    naive_ret = np.log(np.sum(np.exp(x_val), axis=1).T)
-    rewritten_ret = f(x_val)
-    assert np.allclose(naive_ret, rewritten_ret)
-
-
-def test_local_log_sum_exp_large():
-    """Test that the rewrite result is correct for extreme value 100."""
-    x = vector("x")
-    f = compile_graph_log_sum_exp(x, axis=0)
-
-    x_val = np.array([-100.0, 100.0]).astype(config.floatX)
-
-    rewritten_ret = f(x_val)
-    assert np.allclose(rewritten_ret, 100.0)
-
-
-def test_local_log_sum_exp_inf():
-    """Test that when max = +-inf, the rewritten output still works correctly."""
-    x = vector("x")
-    f = compile_graph_log_sum_exp(x, axis=0)
-
-    assert f([-np.inf, -np.inf]) == -np.inf
-    assert f([np.inf, np.inf]) == np.inf
-    assert f([-np.inf, np.inf]) == np.inf
-
-
 def test_local_reciprocal_1_plus_exp():
     x = vector("x")
     y = at.reciprocal(1 + exp(x))
diff --git a/tests/tensor/rewriting/test_special.py b/tests/tensor/rewriting/test_special.py
index a269bfb1ad..56be2068d7 100644
--- a/tests/tensor/rewriting/test_special.py
+++ b/tests/tensor/rewriting/test_special.py
@@ -3,16 +3,22 @@

 import pytensor
 from pytensor import shared
+import pytensor.scalar as aes
 from pytensor.compile import optdb
-from pytensor.compile.mode import get_mode
+from pytensor.compile.function import function
+from pytensor.compile.mode import get_mode, get_default_mode
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import check_stack_trace
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
-from pytensor.tensor.math import add, exp, log, true_div
+from pytensor.tensor.math import add, exp, log, true_div, MaxAndArgmax
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, softmax
-from pytensor.tensor.type import matrix
+from pytensor.tensor.type import matrix, tensor3, TensorType, vector
 from tests import unittest_tools as utt
+from pytensor.tensor.math import sum as at_sum
+from pytensor.tensor.elemwise import DimShuffle
+
+

 _fast_run_rewrites = RewriteDatabaseQuery(include=["fast_run"])
@@ -130,3 +136,88 @@ def f(inputs):
             return pytensor.grad(None, x, known_grads={y: inputs})

         utt.verify_grad(f, [rng.random((3, 4))])
+
+def compile_graph_log_sum_exp(x, axis, dimshuffle_op=None):
+    sum_exp = at_sum(exp(x), axis=axis)
+    if dimshuffle_op:
+        sum_exp = dimshuffle_op(sum_exp)
+    y = log(sum_exp)
+    MODE = get_default_mode().including("local_log_sum_exp")
+    return function([x], y, mode=MODE)
+
+
+def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
+    f = compile_graph_log_sum_exp(x, axis, dimshuffle_op)
+
+    fgraph = f.maker.fgraph.toposort()
+    for node in fgraph:
+        if (
+            hasattr(node.op, "scalar_op")
+            and node.op.scalar_op == aes.basic.scalar_maximum
+        ):
+            return
+
+        # In mode FAST_COMPILE, the rewrites don't replace the
+        # `MaxAndArgmax` `Op`.
+        if isinstance(node.op, MaxAndArgmax):
+            return
+
+    # TODO FIXME: Refactor this test so that it makes a direct assertion and
+    # nothing more.
+    raise AssertionError("No maximum detected after log_sum_exp rewrite")
+
+def test_local_log_sum_exp_maximum():
+    """Test that the rewrite is applied by checking the presence of the maximum."""
+    x = tensor3("x")
+    check_max_log_sum_exp(x, axis=(0,), dimshuffle_op=None)
+    check_max_log_sum_exp(x, axis=(1,), dimshuffle_op=None)
+    check_max_log_sum_exp(x, axis=(2,), dimshuffle_op=None)
+    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=None)
+    check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)
+
+    # If a transpose is applied to the sum
+    transpose_op = DimShuffle((False, False), (1, 0))
+    check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)
+
+    # If the sum is performed with keepdims=True
+    x = TensorType(dtype="floatX", shape=(None, 1, None))("x")
+    sum_keepdims_op = x.sum(axis=(0, 1), keepdims=True).owner.op
+    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=sum_keepdims_op)
+
+def test_local_log_sum_exp_near_one():
+    """Test that the rewritten result is correct around 1.0."""
+
+    x = tensor3("x")
+    x_val = 1.0 + np.random.random((4, 3, 2)).astype(config.floatX) / 10.0
+
+    f = compile_graph_log_sum_exp(x, axis=(1,))
+    naive_ret = np.log(np.sum(np.exp(x_val), axis=1))
+    rewritten_ret = f(x_val)
+    assert np.allclose(naive_ret, rewritten_ret)
+
+    # If a transpose is applied
+    transpose_op = DimShuffle((False, False), (1, 0))
+    f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
+    naive_ret = np.log(np.sum(np.exp(x_val), axis=1).T)
+    rewritten_ret = f(x_val)
+    assert np.allclose(naive_ret, rewritten_ret)
+
+def test_local_log_sum_exp_large():
+    """Test that the rewrite result is correct for extreme value 100."""
+    x = vector("x")
+    f = compile_graph_log_sum_exp(x, axis=0)
+
+    x_val = np.array([-100.0, 100.0]).astype(config.floatX)
+
+    rewritten_ret = f(x_val)
+    assert np.allclose(rewritten_ret, 100.0)
+
+
+def test_local_log_sum_exp_inf():
+    """Test that when max = +-inf, the rewritten output still works correctly."""
+    x = vector("x")
+    f = compile_graph_log_sum_exp(x, axis=0)
+
+    assert f([-np.inf, -np.inf]) == -np.inf
+    assert f([np.inf, np.inf]) == np.inf
+    assert f([-np.inf, np.inf]) == np.inf
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index 8ee5ac4544..276f205e69 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -8,8 +8,6 @@
 import numpy as np
 import pytest
 from numpy.testing import assert_array_equal
-from scipy.special import logsumexp as scipy_logsumexp
-
 import pytensor.scalar as aes
 from pytensor.compile.debugmode import DebugMode
 from pytensor.compile.function import function
@@ -80,7 +78,6 @@
     log2,
     log10,
     logaddexp,
-    logsumexp,
     matmul,
     max,
     max_and_argmax,
@@ -3382,18 +3379,7 @@ def test_logaddexp():
     "keepdims",
     [True, False],
 )
-def test_logsumexp(shape, axis, keepdims):
-    scipy_inp = np.zeros(shape)
-    scipy_out = scipy_logsumexp(scipy_inp, axis=axis, keepdims=keepdims)
-
-    pytensor_inp = as_tensor_variable(scipy_inp)
-    f = function([], logsumexp(pytensor_inp, axis=axis, keepdims=keepdims))
-    pytensor_out = f()
-    np.testing.assert_array_almost_equal(
-        pytensor_out,
-        scipy_out,
-    )


 def test_pprint():
diff --git a/tests/tensor/test_special.py b/tests/tensor/test_special.py
index 17a9c05eff..8cd24ccdd8 100644
--- a/tests/tensor/test_special.py
+++ b/tests/tensor/test_special.py
@@ -4,7 +4,7 @@
 from scipy.special import log_softmax as scipy_log_softmax
 from scipy.special import poch as scipy_poch
 from scipy.special import softmax as scipy_softmax
-
+from scipy.special import logsumexp as scipy_logsumexp
 from pytensor.compile.function import function
 from pytensor.configdefaults import config
 from pytensor.tensor.special import (
@@ -15,7 +15,9 @@
     log_softmax,
     poch,
     softmax,
+    logsumexp,
 )
+from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.type import matrix, tensor3, tensor4, vector, vectors
 from tests import unittest_tools as utt
 from tests.tensor.utils import random_ranged
@@ -171,3 +173,16 @@ def test_factorial(n):
     np.testing.assert_allclose(
         actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5
     )
+
+def test_logsumexp(shape, axis, keepdims):
+    scipy_inp = np.zeros(shape)
+    scipy_out = scipy_logsumexp(scipy_inp, axis=axis, keepdims=keepdims)
+
+    pytensor_inp = as_tensor_variable(scipy_inp)
+    f = function([], logsumexp(pytensor_inp, axis=axis, keepdims=keepdims))
+    pytensor_out = f()
+
+    np.testing.assert_array_almost_equal(
+        pytensor_out,
+        scipy_out,
+    )
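
Usage sketch (not part of the patch): a minimal example of the relocated helper, assuming the diff above is applied. logsumexp is imported from pytensor.tensor.special instead of pytensor.tensor.math, and the local_log_sum_exp rewrite, now registered in pytensor/tensor/special.py, is requested the same way the tests above do; the compiled graph then evaluates log(sum(exp(x))) stably for large-magnitude inputs.

# Minimal usage sketch, assuming the patch above is applied.
import numpy as np

from pytensor import function
from pytensor.compile.mode import get_default_mode
from pytensor.tensor.special import logsumexp
from pytensor.tensor.type import vector

x = vector("x")
y = logsumexp(x, axis=0)  # builds the graph log(sum(exp(x), axis=0))

# Request the rewrite explicitly, mirroring compile_graph_log_sum_exp in the tests.
f = function([x], y, mode=get_default_mode().including("local_log_sum_exp"))

# A naive evaluation would overflow exp(100.0) to inf; the rewritten graph
# returns ~100.0, matching scipy.special.logsumexp.
print(f(np.array([-100.0, 100.0], dtype=x.dtype)))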