@@ -799,10 +799,50 @@ which incur interpreter overhead.
799
799
"Returns the sequence elements n times"
800
800
return chain.from_iterable(repeat(tuple(iterable), n))
801
801
802
+ def batched(iterable, n):
803
+ "Batch data into tuples of length n. The last batch may be shorter."
804
+ # batched('ABCDEFG', 3) --> ABC DEF G
805
+ if n < 1:
806
+ raise ValueError('n must be at least one')
807
+ it = iter(iterable)
808
+ while (batch := tuple(islice(it, n))):
809
+ yield batch
810
+
811
+ def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
812
+ "Collect data into non-overlapping fixed-length chunks or blocks"
813
+ # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
814
+ # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
815
+ # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
816
+ args = [iter(iterable)] * n
817
+ if incomplete == 'fill':
818
+ return zip_longest(*args, fillvalue=fillvalue)
819
+ if incomplete == 'strict':
820
+ return zip(*args, strict=True)
821
+ if incomplete == 'ignore':
822
+ return zip(*args)
823
+ else:
824
+ raise ValueError('Expected fill, strict, or ignore')
825
+
802
826
def sumprod(vec1, vec2):
803
827
"Compute a sum of products."
804
828
return sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))
805
829
830
+ def sum_of_squares(it):
831
+ "Add up the squares of the input values."
832
+ # sum_of_squares([10, 20, 30]) -> 1400
833
+ return sumprod(*tee(it))
834
+
835
+ def transpose(it):
836
+ "Swap the rows and columns of the input."
837
+ # transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33)
838
+ return zip(*it, strict=True)
839
+
840
+ def matmul(m1, m2):
841
+ "Multiply two matrices."
842
+ # matmul([(7, 5), (3, 5)], [[2, 5], [7, 9]]) --> (49, 80), (41, 60)
843
+ n = len(m2[0])
844
+ return batched(starmap(sumprod, product(m1, transpose(m2))), n)
845
+
806
846
def convolve(signal, kernel):
807
847
# See: https://betterexplained.com/articles/intuitive-convolution/
808
848
# convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
@@ -886,30 +926,6 @@ which incur interpreter overhead.
886
926
return starmap(func, repeat(args))
887
927
return starmap(func, repeat(args, times))
888
928
889
- def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
890
- "Collect data into non-overlapping fixed-length chunks or blocks"
891
- # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
892
- # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
893
- # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
894
- args = [iter(iterable)] * n
895
- if incomplete == 'fill':
896
- return zip_longest(*args, fillvalue=fillvalue)
897
- if incomplete == 'strict':
898
- return zip(*args, strict=True)
899
- if incomplete == 'ignore':
900
- return zip(*args)
901
- else:
902
- raise ValueError('Expected fill, strict, or ignore')
903
-
904
- def batched(iterable, n):
905
- "Batch data into tuples of length n. The last batch may be shorter."
906
- # batched('ABCDEFG', 3) --> ABC DEF G
907
- if n < 1:
908
- raise ValueError('n must be at least one')
909
- it = iter(iterable)
910
- while (batch := tuple(islice(it, n))):
911
- yield batch
912
-
913
929
def triplewise(iterable):
914
930
"Return overlapping triplets from an iterable"
915
931
# triplewise('ABCDEFG') --> ABC BCD CDE DEF EFG
@@ -1184,6 +1200,17 @@ which incur interpreter overhead.
1184
1200
>>> sumprod([1 ,2 ,3 ], [4 ,5 ,6 ])
1185
1201
32
1186
1202
1203
+ >>> sum_of_squares([10 , 20 , 30 ])
1204
+ 1400
1205
+
1206
+ >>> list (transpose([(1 , 2 , 3 ), (11 , 22 , 33 )]))
1207
+ [(1, 11), (2, 22), (3, 33)]
1208
+
1209
+ >>> list (matmul([(7 , 5 ), (3 , 5 )], [[2 , 5 ], [7 , 9 ]]))
1210
+ [(49, 80), (41, 60)]
1211
+ >>> list (matmul([[2 , 5 ], [7 , 9 ], [3 , 4 ]], [[7 , 11 , 5 , 4 , 9 ], [3 , 5 , 2 , 6 , 3 ]]))
1212
+ [(29, 47, 20, 38, 33), (76, 122, 53, 82, 90), (33, 53, 23, 36, 39)]
1213
+
1187
1214
>>> data = [20 , 40 , 24 , 32 , 20 , 28 , 16 ]
1188
1215
>>> list (convolve(data, [0.25 , 0.25 , 0.25 , 0.25 ]))
1189
1216
[5.0, 15.0, 21.0, 29.0, 29.0, 26.0, 24.0, 16.0, 11.0, 4.0]
0 commit comments