Skip to content

Commit ba88628

Browse files
authored
[3.11] Add recipes to showcase tee(), zip*, batched, starmap, and product. (GH-101028)
1 parent d06315a commit ba88628

File tree

1 file changed

+51
-24
lines changed

1 file changed

+51
-24
lines changed

Doc/library/itertools.rst

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -799,10 +799,50 @@ which incur interpreter overhead.
799799
"Returns the sequence elements n times"
800800
return chain.from_iterable(repeat(tuple(iterable), n))
801801

802+
def batched(iterable, n):
803+
"Batch data into tuples of length n. The last batch may be shorter."
804+
# batched('ABCDEFG', 3) --> ABC DEF G
805+
if n < 1:
806+
raise ValueError('n must be at least one')
807+
it = iter(iterable)
808+
while (batch := tuple(islice(it, n))):
809+
yield batch
810+
811+
def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
812+
"Collect data into non-overlapping fixed-length chunks or blocks"
813+
# grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
814+
# grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
815+
# grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
816+
args = [iter(iterable)] * n
817+
if incomplete == 'fill':
818+
return zip_longest(*args, fillvalue=fillvalue)
819+
if incomplete == 'strict':
820+
return zip(*args, strict=True)
821+
if incomplete == 'ignore':
822+
return zip(*args)
823+
else:
824+
raise ValueError('Expected fill, strict, or ignore')
825+
802826
def sumprod(vec1, vec2):
803827
"Compute a sum of products."
804828
return sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))
805829

830+
def sum_of_squares(it):
831+
"Add up the squares of the input values."
832+
# sum_of_squares([10, 20, 30]) -> 1400
833+
return sumprod(*tee(it))
834+
835+
def transpose(it):
836+
"Swap the rows and columns of the input."
837+
# transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33)
838+
return zip(*it, strict=True)
839+
840+
def matmul(m1, m2):
841+
"Multiply two matrices."
842+
# matmul([(7, 5), (3, 5)], [[2, 5], [7, 9]]) --> (49, 80), (41, 60)
843+
n = len(m2[0])
844+
return batched(starmap(sumprod, product(m1, transpose(m2))), n)
845+
806846
def convolve(signal, kernel):
807847
# See: https://betterexplained.com/articles/intuitive-convolution/
808848
# convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
@@ -886,30 +926,6 @@ which incur interpreter overhead.
886926
return starmap(func, repeat(args))
887927
return starmap(func, repeat(args, times))
888928

889-
def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
890-
"Collect data into non-overlapping fixed-length chunks or blocks"
891-
# grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
892-
# grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
893-
# grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
894-
args = [iter(iterable)] * n
895-
if incomplete == 'fill':
896-
return zip_longest(*args, fillvalue=fillvalue)
897-
if incomplete == 'strict':
898-
return zip(*args, strict=True)
899-
if incomplete == 'ignore':
900-
return zip(*args)
901-
else:
902-
raise ValueError('Expected fill, strict, or ignore')
903-
904-
def batched(iterable, n):
905-
"Batch data into tuples of length n. The last batch may be shorter."
906-
# batched('ABCDEFG', 3) --> ABC DEF G
907-
if n < 1:
908-
raise ValueError('n must be at least one')
909-
it = iter(iterable)
910-
while (batch := tuple(islice(it, n))):
911-
yield batch
912-
913929
def triplewise(iterable):
914930
"Return overlapping triplets from an iterable"
915931
# triplewise('ABCDEFG') --> ABC BCD CDE DEF EFG
@@ -1184,6 +1200,17 @@ which incur interpreter overhead.
11841200
>>> sumprod([1,2,3], [4,5,6])
11851201
32
11861202

1203+
>>> sum_of_squares([10, 20, 30])
1204+
1400
1205+
1206+
>>> list(transpose([(1, 2, 3), (11, 22, 33)]))
1207+
[(1, 11), (2, 22), (3, 33)]
1208+
1209+
>>> list(matmul([(7, 5), (3, 5)], [[2, 5], [7, 9]]))
1210+
[(49, 80), (41, 60)]
1211+
>>> list(matmul([[2, 5], [7, 9], [3, 4]], [[7, 11, 5, 4, 9], [3, 5, 2, 6, 3]]))
1212+
[(29, 47, 20, 38, 33), (76, 122, 53, 82, 90), (33, 53, 23, 36, 39)]
1213+
11871214
>>> data = [20, 40, 24, 32, 20, 28, 16]
11881215
>>> list(convolve(data, [0.25, 0.25, 0.25, 0.25]))
11891216
[5.0, 15.0, 21.0, 29.0, 29.0, 26.0, 24.0, 16.0, 11.0, 4.0]

0 commit comments

Comments
 (0)