Skip to content

Commit fed6431

Browse files
committed
Faster factorize
1 parent 3794f70 commit fed6431

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

flox/core.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,26 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
556556
return offset, size
557557

558558

559+
def fast_isin(ar1, ar2, invert=False):
    """
    Sort-based membership test of every element of ``ar1`` against ``ar2``.

    ``ar1`` is first reduced to its unique values with ``pd.factorize``; the
    unique values are concatenated with ``ar2``, stably sorted, and adjacent
    equal values in the sorted result mark the values of ``ar1`` that also
    occur in ``ar2``. The per-unique-value answer is then expanded back to
    the length of ``ar1`` through the factorize codes.

    Parameters
    ----------
    ar1 : 1D array-like accepted by ``pd.factorize``
        Values to test.
    ar2 : 1D array-like
        Values to test against.
    invert : bool, optional
        If True, return True where an element of ``ar1`` is NOT in ``ar2``.

    Returns
    -------
    np.ndarray of bool
        One entry per element of ``ar1``. Missing values in ``ar1`` (code -1
        from ``pd.factorize``) are always reported as "not in ``ar2``", i.e.
        they yield ``invert``.
    """
    # codes[i] is the position of ar1[i] in uniques, or -1 when ar1[i] is
    # missing (missing values are left out of uniques by pd.factorize).
    codes, uniques = pd.factorize(ar1, sort=False)

    combined = np.concatenate((uniques, ar2))
    # We need this to be a stable sort, so always use 'mergesort'
    # here. The values from the first array should always come before
    # the values from the second array.
    order = combined.argsort(kind="mergesort")
    sar = combined[order]
    if invert:
        neighbour_flag = sar[1:] != sar[:-1]
    else:
        neighbour_flag = sar[1:] == sar[:-1]

    # One slot per concatenated value plus one explicit trailing slot.
    # Everything defaults to "not found" (== invert): that covers the last
    # value in sorted order, which has no right-hand neighbour, and the
    # trailing slot, which is what code -1 (missing value) indexes. Without
    # the extra slot, -1 would read the flag of an unrelated value, or raise
    # IndexError when both uniques and ar2 are empty.
    ret = np.full(combined.size + 1, invert, dtype=bool)
    ret[order[:-1]] = neighbour_flag

    return ret[codes]
577+
578+
559579
@overload
560580
def factorize_(
561581
by: T_Bys,
@@ -654,14 +674,20 @@ def factorize_(
654674
if expect is not None and reindex:
655675
sorter = np.argsort(expect)
656676
groups = expect[(sorter,)] if sort else expect
657-
idx = np.searchsorted(expect, flat, sorter=sorter)
658-
mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect))
677+
678+
mask = fast_isin(flat, expect, invert=True)
679+
if not np.issubdtype(flat.dtype, np.integer):
680+
mask |= isnull(flat)
681+
682+
idx = np.full(flat.shape, -1)
683+
result = np.searchsorted(expect.values, flat[~mask], sorter=sorter)
684+
idx[~mask] = result
685+
# idx = np.searchsorted(expect.values, flat, sorter=sorter)
686+
# idx[mask] = -1
659687
if not sort:
660688
# idx is the index in to the sorted array.
661689
# if we didn't want sorting, unsort it back
662-
idx[(idx == len(expect),)] = -1
663690
idx = sorter[(idx,)]
664-
idx[mask] = -1
665691
else:
666692
idx, groups = pd.factorize(flat, sort=sort) # type: ignore[arg-type]
667693

0 commit comments

Comments
 (0)