Skip to content

Commit 039fbeb

Browse files
XuehaiPanpytorchmergebot
authored andcommitted
[dynamo] fix functools.reduce() function with None as initial (#116398)
The `initial` argument in `functools.reduce` can be `None`. ```python initial_missing = object() def reduce(function, iterable, initial=initial_missing, /): it = iter(iterable) if initial is initial_missing: value = next(it) else: value = initial for element in it: value = function(value, element) return value ``` Reference: - python/cpython#102759 Pull Request resolved: #116398 Approved by: https://github.com/Skylion007
1 parent c7e9c15 commit 039fbeb

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

test/dynamo/test_functions.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,19 @@ def func_with_default(a, b, some_default_arg=True):
6363
return a - b
6464

6565

66-
def make_test(fn):
66+
def make_test(fn=None, expected_frame_count=1):
67+
if fn is None:
68+
return lambda fn: make_test(fn, expected_frame_count=expected_frame_count)
69+
6770
nargs = len(inspect.signature(fn).parameters)
6871

6972
def test_fn(self):
70-
return torch._dynamo.testing.standard_test(self, fn=fn, nargs=nargs)
73+
return torch._dynamo.testing.standard_test(
74+
self,
75+
fn=fn,
76+
nargs=nargs,
77+
expected_frame_count=expected_frame_count,
78+
)
7179

7280
return test_fn
7381

@@ -870,6 +878,22 @@ def test_map_sum(a, b, c, d):
870878
def test_reduce(a, b, c, d):
871879
return functools.reduce(operator.add, [a, b, c, d])
872880

881+
@make_test
882+
def test_reduce_with_initial(a, b, c, d):
883+
return functools.reduce(operator.add, [b, c, d], a)
884+
885+
@make_test(expected_frame_count=0)
886+
def test_reduce_with_single(x):
887+
return functools.reduce(lambda a, b: (a, b), [x])
888+
889+
@make_test(expected_frame_count=0)
890+
def test_reduce_with_single_with_initial(x, y):
891+
return functools.reduce(lambda a, b: (a, b), [y], x)
892+
893+
@make_test(expected_frame_count=0)
894+
def test_reduce_with_none_initial(x):
895+
return functools.reduce(lambda a, b: (a, b), [x], None)
896+
873897
@make_test
874898
def test_tuple_contains(a, b):
875899
v1 = "a"

torch/_dynamo/testing.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,14 @@ def normalize_gm(gm_str) -> str:
244244
return remove_trailing_space(strip_comment(gm_str))
245245

246246

247-
def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None):
247+
def standard_test(
248+
self,
249+
fn,
250+
nargs,
251+
expected_ops=None,
252+
expected_ops_dynamic=None,
253+
expected_frame_count=1,
254+
):
248255
if not config.assume_static_by_default and expected_ops_dynamic is not None:
249256
expected_ops = expected_ops_dynamic
250257

@@ -265,7 +272,7 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None)
265272
self.assertTrue(same(val1b, correct1))
266273
self.assertTrue(same(val2a, correct2))
267274
self.assertTrue(same(val2b, correct2))
268-
self.assertEqual(actual.frame_count, 1)
275+
self.assertEqual(actual.frame_count, expected_frame_count)
269276
if expected_ops is not None:
270277
self.assertEqual(actual.op_count, expected_ops)
271278

torch/_dynamo/variables/builtin.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def call_fn(self, tx, *args, **kwargs):
8686

8787

8888
class BuiltinVariable(VariableTracker):
89+
_SENTINEL = object()
90+
8991
@staticmethod
9092
@functools.lru_cache(None)
9193
def _constant_fold_functions():
@@ -1100,13 +1102,13 @@ def call_sum(self, tx, seq, **kwargs):
11001102
{},
11011103
)
11021104

1103-
def call_reduce(self, tx, function, iterable, initializer=None):
1105+
def call_reduce(self, tx, function, iterable, initial=_SENTINEL):
11041106
if iterable.has_unpack_var_sequence(tx):
11051107
items = iterable.unpack_var_sequence(tx)
1106-
if initializer is None:
1108+
if initial is self._SENTINEL:
11071109
value, items = items[0], items[1:]
11081110
else:
1109-
value = initializer
1111+
value = initial
11101112
for element in items:
11111113
value = function.call_function(tx, [value, element], {})
11121114
return value

0 commit comments

Comments
 (0)