Skip to content

Commit d0c9957

Browse files
committed
cython methods for group bins #1809
1 parent 0b18a91 commit d0c9957

File tree

3 files changed

+106
-8
lines changed

3 files changed

+106
-8
lines changed

pandas/core/groupby.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -754,12 +754,8 @@ def _aggregate(self, result, counts, values, how, is_numeric):
754754
raise NotImplementedError
755755
elif values.ndim > 2:
756756
for i, chunk in enumerate(values.transpose(2, 0, 1)):
757-
if not is_numeric:
758-
print 'getting results for %d' % i
759757
agg_func(result[:, :, i], counts, chunk.squeeze(),
760758
comp_ids)
761-
if not is_numeric:
762-
print 'got results for %d' % i
763759
else:
764760
agg_func(result, counts, values, comp_ids)
765761

@@ -937,14 +933,22 @@ def names(self):
937933
'last': lib.group_last_bin
938934
}
939935

936+
_cython_object_functions = {
937+
'first' : lambda a, b, c, d: lib.group_nth_bin_object(a, b, c, d, 1),
938+
'last' : lib.group_last_bin_object
939+
}
940+
940941
_name_functions = {
941942
'ohlc' : lambda *args: ['open', 'high', 'low', 'close']
942943
}
943944

944945
_filter_empty_groups = True
945946

946-
def _aggregate(self, result, counts, values, how):
947-
agg_func = self._cython_functions[how]
947+
def _aggregate(self, result, counts, values, how, is_numeric=True):
948+
fdict = self._cython_functions
949+
if not is_numeric:
950+
fdict = self._cython_object_functions
951+
agg_func = fdict[how]
948952
trans_func = self._cython_transforms.get(how, lambda x: x)
949953

950954
if values.ndim > 3:

pandas/src/groupby.pyx

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,54 @@ def group_nth_bin(ndarray[float64_t, ndim=2] out,
466466
else:
467467
out[i, j] = resx[i, j]
468468

469+
@cython.boundscheck(False)
@cython.wraparound(False)
def group_nth_bin_object(ndarray[object, ndim=2] out,
                         ndarray[int64_t] counts,
                         ndarray[object, ndim=2] values,
                         ndarray[int64_t] bins, int64_t rank):
    '''
    Only aggregates on axis=0

    For every bin (group of consecutive rows) and every column, store in
    `out` the `rank`-th non-null value, counting from 1. A value is
    treated as null when it does not compare equal to itself.

    Parameters
    ----------
    out : 2-d object ndarray, written in place
        NOTE(review): assumed to have one row per group and one column
        per column of `values` -- confirm against the caller.
    counts : int64 ndarray, incremented in place
        Gains one per row assigned to each group, null or not.
    values : 2-d object ndarray of observations (rows x columns)
    bins : int64 ndarray of bin edges
        Row i is assigned to the first bin whose edge is greater than i.
        NOTE(review): assumed to be sorted ascending -- confirm. If the
        last edge differs from len(values), one extra trailing group is
        used for the remaining rows.
    rank : 1-based position, among the non-null values of a group, of
        the value to pick.

    Cells whose group has no non-null values, or fewer than `rank` of
    them, are set to nan. When `bins` is empty the function returns
    without touching `out` or `counts`.
    '''
    cdef:
        Py_ssize_t i, j, N, K, ngroups, b
        object val
        ndarray[object, ndim=2] resx
        ndarray[float64_t, ndim=2] nobs

    # Bounds checking and negative-index wraparound are disabled above,
    # so bins[len(bins) - 1] on an empty array would read out of bounds.
    if len(bins) == 0:
        return

    nobs = np.zeros((<object> out).shape, dtype=np.float64)
    resx = np.empty((<object> out).shape, dtype=object)

    if bins[len(bins) - 1] == len(values):
        ngroups = len(bins)
    else:
        ngroups = len(bins) + 1

    N, K = (<object> values).shape

    b = 0
    for i in range(N):
        # advance to the bin that contains row i
        while b < ngroups - 1 and i >= bins[b]:
            b += 1

        counts[b] += 1
        for j in range(K):
            val = values[i, j]

            # not nan
            if val == val:
                nobs[b, j] += 1
                if nobs[b, j] == rank:
                    resx[b, j] = val

    for i in range(ngroups):
        for j in range(K):
            # resx is only filled once `rank` non-null values were seen;
            # anything short of that must not leak the placeholder.
            if nobs[i, j] == 0 or nobs[i, j] < rank:
                out[i, j] = nan
            else:
                out[i, j] = resx[i, j]
516+
469517
@cython.boundscheck(False)
470518
@cython.wraparound(False)
471519
def group_last(ndarray[float64_t, ndim=2] out,
@@ -595,6 +643,53 @@ def group_last_bin(ndarray[float64_t, ndim=2] out,
595643
else:
596644
out[i, j] = resx[i, j]
597645

646+
@cython.boundscheck(False)
@cython.wraparound(False)
def group_last_bin_object(ndarray[object, ndim=2] out,
                          ndarray[int64_t] counts,
                          ndarray[object, ndim=2] values,
                          ndarray[int64_t] bins):
    '''
    Only aggregates on axis=0

    For every bin (group of consecutive rows) and every column, store in
    `out` the last non-null value. A value is treated as null when it
    does not compare equal to itself.

    Parameters
    ----------
    out : 2-d object ndarray, written in place
        NOTE(review): assumed to have one row per group and one column
        per column of `values` -- confirm against the caller.
    counts : int64 ndarray, incremented in place
        Gains one per row assigned to each group, null or not.
    values : 2-d object ndarray of observations (rows x columns)
    bins : int64 ndarray of bin edges
        Row i is assigned to the first bin whose edge is greater than i.
        NOTE(review): assumed to be sorted ascending -- confirm. If the
        last edge differs from len(values), one extra trailing group is
        used for the remaining rows.

    Cells whose group has no non-null values are set to nan. When `bins`
    is empty the function returns without touching `out` or `counts`.
    '''
    cdef:
        Py_ssize_t i, j, N, K, ngroups, b
        object val
        ndarray[object, ndim=2] resx
        ndarray[float64_t, ndim=2] nobs

    # Bounds checking and negative-index wraparound are disabled above,
    # so bins[len(bins) - 1] on an empty array would read out of bounds.
    if len(bins) == 0:
        return

    nobs = np.zeros((<object> out).shape, dtype=np.float64)
    resx = np.empty((<object> out).shape, dtype=object)

    if bins[len(bins) - 1] == len(values):
        ngroups = len(bins)
    else:
        ngroups = len(bins) + 1

    N, K = (<object> values).shape

    b = 0
    for i in range(N):
        # advance to the bin that contains row i
        while b < ngroups - 1 and i >= bins[b]:
            b += 1

        counts[b] += 1
        for j in range(K):
            val = values[i, j]

            # not nan
            if val == val:
                nobs[b, j] += 1
                # keep overwriting so the last non-null value wins
                resx[b, j] = val

    for i in range(ngroups):
        for j in range(K):
            if nobs[i, j] == 0:
                out[i, j] = nan
            else:
                out[i, j] = resx[i, j]
692+
598693
#----------------------------------------------------------------------
599694
# group_min, group_max
600695

pandas/tseries/tests/test_timeseries.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,7 +1080,7 @@ def test_frame_datetime64_handling_groupby(self):
10801080
(3,np.datetime64('2012-07-04'))],
10811081
columns = ['a', 'date'])
10821082
result = df.groupby('a').first()
1083-
self.assertEqual(result['date'][3].year, 2012)
1083+
self.assertEqual(result['date'][3], np.datetime64('2012-07-03'))
10841084

10851085
def test_series_interpolate_intraday(self):
10861086
# #1698
@@ -2190,4 +2190,3 @@ def test_hash_equivalent(self):
21902190
if __name__ == '__main__':
21912191
nose.runmodule(argv=[__file__,'-vvs','-x','--pdb', '--pdb-failure'],
21922192
exit=False)
2193-

0 commit comments

Comments
 (0)