Skip to content

Commit de3ece7

Browse files
authored
GH-98363: Add itertools.batched() (GH-98364)
1 parent 70732d8 commit de3ece7

File tree

5 files changed

+370
-39
lines changed

5 files changed

+370
-39
lines changed

Doc/library/itertools.rst

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ Iterator Arguments Results
4848
Iterator Arguments Results Example
4949
============================ ============================ ================================================= =============================================================
5050
:func:`accumulate` p [,func] p0, p0+p1, p0+p1+p2, ... ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15``
51+
:func:`batched` p, n [p0, p1, ..., p_n-1], ... ``batched('ABCDEFG', n=3) --> ABC DEF G``
5152
:func:`chain` p, q, ... p0, p1, ... plast, q0, q1, ... ``chain('ABC', 'DEF') --> A B C D E F``
5253
:func:`chain.from_iterable` iterable p0, p1, ... plast, q0, q1, ... ``chain.from_iterable(['ABC', 'DEF']) --> A B C D E F``
5354
:func:`compress` data, selectors (d[0] if s[0]), (d[1] if s[1]), ... ``compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F``
@@ -170,6 +171,44 @@ loops that truncate the stream.
170171
.. versionchanged:: 3.8
171172
Added the optional *initial* parameter.
172173

174+
175+
.. function:: batched(iterable, n)
176+
177+
Batch data from the *iterable* into lists of length *n*. The last
178+
batch may be shorter than *n*.
179+
180+
Loops over the input iterable and accumulates data into lists up to
181+
size *n*. The input is consumed lazily, just enough to fill a list.
182+
The result is yielded as soon as the batch is full or when the input
183+
iterable is exhausted:
184+
185+
.. doctest::
186+
187+
>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
188+
>>> unflattened = list(batched(flattened_data, 2))
189+
>>> unflattened
190+
[['roses', 'red'], ['violets', 'blue'], ['sugar', 'sweet']]
191+
192+
>>> for batch in batched('ABCDEFG', 3):
193+
... print(batch)
194+
...
195+
['A', 'B', 'C']
196+
['D', 'E', 'F']
197+
['G']
198+
199+
Roughly equivalent to::
200+
201+
def batched(iterable, n):
202+
# batched('ABCDEFG', 3) --> ABC DEF G
203+
if n < 1:
204+
raise ValueError('n must be at least one')
205+
it = iter(iterable)
206+
while (batch := list(islice(it, n))):
207+
yield batch
208+
209+
.. versionadded:: 3.12
210+
211+
173212
.. function:: chain(*iterables)
174213

175214
Make an iterator that returns elements from the first iterable until it is
@@ -858,13 +897,6 @@ which incur interpreter overhead.
858897
else:
859898
raise ValueError('Expected fill, strict, or ignore')
860899
861-
def batched(iterable, n):
862-
"Batch data into lists of length n. The last batch may be shorter."
863-
# batched('ABCDEFG', 3) --> ABC DEF G
864-
it = iter(iterable)
865-
while (batch := list(islice(it, n))):
866-
yield batch
867-
868900
def triplewise(iterable):
869901
"Return overlapping triplets from an iterable"
870902
# triplewise('ABCDEFG') --> ABC BCD CDE DEF EFG
@@ -1211,36 +1243,6 @@ which incur interpreter overhead.
12111243
>>> list(grouper('abcdefg', n=3, incomplete='ignore'))
12121244
[('a', 'b', 'c'), ('d', 'e', 'f')]
12131245

1214-
>>> list(batched('ABCDEFG', 3))
1215-
[['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]
1216-
>>> list(batched('ABCDEF', 3))
1217-
[['A', 'B', 'C'], ['D', 'E', 'F']]
1218-
>>> list(batched('ABCDE', 3))
1219-
[['A', 'B', 'C'], ['D', 'E']]
1220-
>>> list(batched('ABCD', 3))
1221-
[['A', 'B', 'C'], ['D']]
1222-
>>> list(batched('ABC', 3))
1223-
[['A', 'B', 'C']]
1224-
>>> list(batched('AB', 3))
1225-
[['A', 'B']]
1226-
>>> list(batched('A', 3))
1227-
[['A']]
1228-
>>> list(batched('', 3))
1229-
[]
1230-
>>> list(batched('ABCDEFG', 2))
1231-
[['A', 'B'], ['C', 'D'], ['E', 'F'], ['G']]
1232-
>>> list(batched('ABCDEFG', 1))
1233-
[['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G']]
1234-
>>> list(batched('ABCDEFG', 0))
1235-
[]
1236-
>>> list(batched('ABCDEFG', -1))
1237-
Traceback (most recent call last):
1238-
...
1239-
ValueError: Stop argument for islice() must be None or an integer: 0 <= x <= sys.maxsize.
1240-
>>> s = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
1241-
>>> all(list(flatten(batched(s[:n], 5))) == list(s[:n]) for n in range(len(s)))
1242-
True
1243-
12441246
>>> list(triplewise('ABCDEFG'))
12451247
[('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E'), ('D', 'E', 'F'), ('E', 'F', 'G')]
12461248

Lib/test/test_itertools.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,44 @@ def test_accumulate(self):
159159
with self.assertRaises(TypeError):
160160
list(accumulate([10, 20], 100))
161161

162+
def test_batched(self):
163+
self.assertEqual(list(batched('ABCDEFG', 3)),
164+
[['A', 'B', 'C'], ['D', 'E', 'F'], ['G']])
165+
self.assertEqual(list(batched('ABCDEFG', 2)),
166+
[['A', 'B'], ['C', 'D'], ['E', 'F'], ['G']])
167+
self.assertEqual(list(batched('ABCDEFG', 1)),
168+
[['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G']])
169+
170+
with self.assertRaises(TypeError): # Too few arguments
171+
list(batched('ABCDEFG'))
172+
with self.assertRaises(TypeError):
173+
list(batched('ABCDEFG', 3, None)) # Too many arguments
174+
with self.assertRaises(TypeError):
175+
list(batched(None, 3)) # Non-iterable input
176+
with self.assertRaises(TypeError):
177+
list(batched('ABCDEFG', 'hello')) # n is a string
178+
with self.assertRaises(ValueError):
179+
list(batched('ABCDEFG', 0)) # n is zero
180+
with self.assertRaises(ValueError):
181+
list(batched('ABCDEFG', -1)) # n is negative
182+
183+
data = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
184+
for n in range(1, 6):
185+
for i in range(len(data)):
186+
s = data[:i]
187+
batches = list(batched(s, n))
188+
with self.subTest(s=s, n=n, batches=batches):
189+
# Order is preserved and no data is lost
190+
self.assertEqual(''.join(chain(*batches)), s)
191+
# Each batch is an exact list
192+
self.assertTrue(all(type(batch) is list for batch in batches))
193+
# All but the last batch is of size n
194+
if batches:
195+
last_batch = batches.pop()
196+
self.assertTrue(all(len(batch) == n for batch in batches))
197+
self.assertTrue(len(last_batch) <= n)
198+
batches.append(last_batch)
199+
162200
def test_chain(self):
163201

164202
def chain2(*iterables):
@@ -1737,6 +1775,31 @@ def test_takewhile(self):
17371775

17381776
class TestPurePythonRoughEquivalents(unittest.TestCase):
17391777

1778+
def test_batched_recipe(self):
1779+
def batched_recipe(iterable, n):
1780+
"Batch data into lists of length n. The last batch may be shorter."
1781+
# batched('ABCDEFG', 3) --> ABC DEF G
1782+
if n < 1:
1783+
raise ValueError('n must be at least one')
1784+
it = iter(iterable)
1785+
while (batch := list(islice(it, n))):
1786+
yield batch
1787+
1788+
for iterable, n in product(
1789+
['', 'a', 'ab', 'abc', 'abcd', 'abcde', 'abcdef', 'abcdefg', None],
1790+
[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, None]):
1791+
with self.subTest(iterable=iterable, n=n):
1792+
try:
1793+
e1, r1 = None, list(batched(iterable, n))
1794+
except Exception as e:
1795+
e1, r1 = type(e), None
1796+
try:
1797+
e2, r2 = None, list(batched_recipe(iterable, n))
1798+
except Exception as e:
1799+
e2, r2 = type(e), None
1800+
self.assertEqual(r1, r2)
1801+
self.assertEqual(e1, e2)
1802+
17401803
@staticmethod
17411804
def islice(iterable, *args):
17421805
s = slice(*args)
@@ -1788,6 +1851,10 @@ def test_accumulate(self):
17881851
a = []
17891852
self.makecycle(accumulate([1,2,a,3]), a)
17901853

1854+
def test_batched(self):
1855+
a = []
1856+
self.makecycle(batched([1,2,a,3], 2), a)
1857+
17911858
def test_chain(self):
17921859
a = []
17931860
self.makecycle(chain(a), a)
@@ -1972,6 +2039,18 @@ def test_accumulate(self):
19722039
self.assertRaises(TypeError, accumulate, N(s))
19732040
self.assertRaises(ZeroDivisionError, list, accumulate(E(s)))
19742041

2042+
def test_batched(self):
2043+
s = 'abcde'
2044+
r = [['a', 'b'], ['c', 'd'], ['e']]
2045+
n = 2
2046+
for g in (G, I, Ig, L, R):
2047+
with self.subTest(g=g):
2048+
self.assertEqual(list(batched(g(s), n)), r)
2049+
self.assertEqual(list(batched(S(s), 2)), [])
2050+
self.assertRaises(TypeError, batched, X(s), 2)
2051+
self.assertRaises(TypeError, batched, N(s), 2)
2052+
self.assertRaises(ZeroDivisionError, list, batched(E(s), 2))
2053+
19752054
def test_chain(self):
19762055
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
19772056
for g in (G, I, Ig, S, L, R):
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Added itertools.batched() to batch data into lists of a given length with
2+
the last list possibly being shorter than the others.

Modules/clinic/itertoolsmodule.c.h

Lines changed: 80 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)