Skip to content

Commit f89ae34

Browse files
committed
Merge pull request #6132 from hayd/str_get_dummies
ENH get_dummies str method
2 parents a8bc986 + d8f94e9 commit f89ae34

File tree

6 files changed

+90
-5
lines changed

6 files changed

+90
-5
lines changed

doc/source/basics.rst

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1155,7 +1155,6 @@ can also be used.
11551155
Testing for Strings that Match or Contain a Pattern
11561156
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
11571157

1158-
11591158
You can check whether elements contain a pattern:
11601159

11611160
.. ipython:: python
@@ -1221,6 +1220,21 @@ Methods like ``match``, ``contains``, ``startswith``, and ``endswith`` take
12211220
``lower``,Equivalent to ``str.lower``
12221221
``upper``,Equivalent to ``str.upper``
12231222

1223+
1224+
Getting indicator variables from separated strings
1225+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1226+
1227+
You can extract dummy variables from string columns.
1228+
For example, if they are separated by a ``'|'``:
1229+
1230+
.. ipython:: python
1231+
1232+
s = pd.Series(['a', 'a|b', np.nan, 'a|c'])
1233+
s.str.get_dummies(sep='|')
1234+
1235+
See also ``pd.get_dummies``.
1236+
1237+
12241238
.. _basics.sorting:
12251239

12261240
Sorting by index and value

doc/source/v0.13.1.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ API changes
4343
- Add ``-NaN`` and ``-nan`` to the default set of NA values (:issue:`5952`).
4444
See :ref:`NA Values <io.na_values>`.
4545

46+
- Added ``Series.str.get_dummies`` vectorized string method (:issue:`6021`), to extract
47+
dummy/indicator variables for separated string columns:
48+
49+
.. ipython:: python
50+
51+
s = Series(['a', 'a|b', np.nan, 'a|c'])
52+
s.str.get_dummies(sep='|')
53+
4654
- Added the ``NDFrame.equals()`` method to compare if two NDFrames are
4755
equal, i.e. have equal axes, dtypes, and values. Added the
4856
``array_equivalent`` function to compare if two ndarrays are

pandas/core/reshape.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,8 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False):
941941
1 0 1 0
942942
2 0 0 1
943943
944+
See also ``Series.str.get_dummies``.
945+
944946
"""
945947
# Series avoids inconsistent NaN handling
946948
cat = Categorical.from_array(Series(data))

pandas/core/strings.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def str_contains(arr, pat, case=True, flags=0, na=np.nan, regex=True):
187187
else:
188188
f = lambda x: pat in x
189189
return _na_map(f, arr, na)
190-
191190

192191

193192
def str_startswith(arr, pat, na=np.nan):
@@ -460,6 +459,46 @@ def f(x):
460459
return result
461460

462461

462+
def str_get_dummies(arr, sep='|'):
463+
"""
464+
Split each string by sep and return a frame of dummy/indicator variables.
465+
466+
Examples
467+
--------
468+
>>> Series(['a|b', 'a', 'a|c']).str.get_dummies()
469+
a b c
470+
0 1 1 0
471+
1 1 0 0
472+
2 1 0 1
473+
474+
>>> pd.Series(['a|b', np.nan, 'a|c']).str.get_dummies()
475+
a b c
476+
0 1 1 0
477+
1 0 0 0
478+
2 1 0 1
479+
480+
See also ``pd.get_dummies``.
481+
482+
"""
483+
# TODO remove this hack?
484+
arr = arr.fillna('')
485+
try:
486+
arr = sep + arr + sep
487+
except TypeError:
488+
arr = sep + arr.astype(str) + sep
489+
490+
tags = set()
491+
for ts in arr.str.split(sep):
492+
tags.update(ts)
493+
tags = sorted(tags - set([""]))
494+
495+
dummies = np.empty((len(arr), len(tags)), dtype=int)
496+
497+
for i, t in enumerate(tags):
498+
pat = sep + t + sep
499+
dummies[:, i] = lib.map_infer(arr.values, lambda x: pat in x)
500+
return DataFrame(dummies, arr.index, tags)
501+
463502

464503
def str_join(arr, sep):
465504
"""
@@ -843,7 +882,7 @@ def contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
843882
result = str_contains(self.series, pat, case=case, flags=flags,
844883
na=na, regex=regex)
845884
return self._wrap_result(result)
846-
885+
847886
@copy(str_replace)
848887
def replace(self, pat, repl, n=-1, case=True, flags=0):
849888
result = str_replace(self.series, pat, repl, n=n, case=case,
@@ -899,6 +938,11 @@ def rstrip(self, to_strip=None):
899938
result = str_rstrip(self.series, to_strip)
900939
return self._wrap_result(result)
901940

941+
@copy(str_get_dummies)
942+
def get_dummies(self, sep='|'):
943+
result = str_get_dummies(self.series, sep)
944+
return self._wrap_result(result)
945+
902946
count = _pat_wrapper(str_count, flags=True)
903947
startswith = _pat_wrapper(str_startswith, na=True)
904948
endswith = _pat_wrapper(str_endswith, na=True)

pandas/tests/test_strings.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,6 @@ def test_replace(self):
366366
result = values.str.replace("(?<=\w),(?=\w)", ", ", flags=re.UNICODE)
367367
tm.assert_series_equal(result, exp)
368368

369-
370369
def test_repeat(self):
371370
values = Series(['a', 'b', NA, 'c', NA, 'd'])
372371

@@ -465,7 +464,7 @@ def test_extract(self):
465464
# Contains tests like those in test_match and some others.
466465

467466
values = Series(['fooBAD__barBAD', NA, 'foo'])
468-
er = [NA, NA] # empty row
467+
er = [NA, NA] # empty row
469468

470469
result = values.str.extract('.*(BAD[_]+).*(BAD)')
471470
exp = DataFrame([['BAD__', 'BAD'], er, er])
@@ -549,6 +548,19 @@ def test_extract(self):
549548
exp = DataFrame([['A', '1'], ['B', '2'], ['C', NA]], columns=['letter', 'number'])
550549
tm.assert_frame_equal(result, exp)
551550

551+
def test_get_dummies(self):
552+
s = Series(['a|b', 'a|c', np.nan])
553+
result = s.str.get_dummies('|')
554+
expected = DataFrame([[1, 1, 0], [1, 0, 1], [0, 0, 0]],
555+
columns=list('abc'))
556+
tm.assert_frame_equal(result, expected)
557+
558+
s = Series(['a;b', 'a', 7])
559+
result = s.str.get_dummies(';')
560+
expected = DataFrame([[0, 1, 1], [0, 1, 0], [1, 0, 0]],
561+
columns=list('7ab'))
562+
tm.assert_frame_equal(result, expected)
563+
552564
def test_join(self):
553565
values = Series(['a_b_c', 'c_d_e', np.nan, 'f_g_h'])
554566
result = values.str.split('_').str.join('_')

vb_suite/strings.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def make_series(letters, strlen, size):
4545
strings_rstrip = Benchmark("many.str.rstrip('matchthis')", setup)
4646
strings_get = Benchmark("many.str.get(0)", setup)
4747

48+
setup = setup + """
49+
make_series(string.uppercase, strlen=10, size=10000).str.join('|')
50+
"""
51+
strings_get_dummies = Benchmark("s.str.get_dummies('|')", setup)
52+
4853
setup = common_setup + """
4954
import pandas.util.testing as testing
5055
ser = pd.Series(testing.makeUnicodeIndex())

0 commit comments

Comments
 (0)