diff --git a/doc/source/visualization.rst b/doc/source/visualization.rst index f30d6c9d5d4c0..42fa3da67ff4b 100644 --- a/doc/source/visualization.rst +++ b/doc/source/visualization.rst @@ -8,7 +8,7 @@ import pandas as pd from numpy.random import randn, rand, randint np.random.seed(123456) - from pandas import DataFrame, Series, date_range, options + from pandas import DataFrame, Series, date_range, options, Categorical import pandas.util.testing as tm np.set_printoptions(precision=4, suppress=True) import matplotlib.pyplot as plt @@ -587,20 +587,43 @@ each point: plt.close('all') -You can pass other keywords supported by matplotlib ``scatter``. -Below example shows a bubble chart using a dataframe column values as bubble size. +You can also pass a column name as the ``s`` (size) argument to have +the point sizes scale according to that column's values. Currently +this is only supported for string column names. + +The minimum and +maximum sizes of the bubbles (in points) are controlled by the +``size_range`` argument, with a default range of ``(50, 1000)``. The +below example shows a bubble chart using a dataframe column values +as bubble size. + +.. ipython:: python + + @savefig scatter_plot_sizes.png + df.plot(kind='scatter', x='a', y='b', s='c'); + +.. ipython:: python + :suppress: + + plt.close('all') + +Categorical columns can also be used to set point sizes, producing +a set of equally spaced point sizes: .. ipython:: python - @savefig scatter_plot_bubble.png - df.plot(kind='scatter', x='a', y='b', s=df['c']*200); + df['group'] = Categorical(randint(1, 4, 50)) + @savefig scatter_plot_categorical_sizes.png + df.plot(kind='scatter', x='a', y='b', s='group') .. ipython:: python :suppress: plt.close('all') -See the :meth:`scatter ` method and the +You can pass other keywords supported by matplotlib ``scatter``, e.g. ``alpha`` +to control the transparency of points. See the +:meth:`scatter ` method and the `matplotlib scatter documenation `__ for more. .. _visualization.hexbin: diff --git a/doc/source/whatsnew/v0.15.2.txt b/doc/source/whatsnew/v0.15.2.txt index 6688f106f922e..6ecba9e0e0ad8 100644 --- a/doc/source/whatsnew/v0.15.2.txt +++ b/doc/source/whatsnew/v0.15.2.txt @@ -66,6 +66,7 @@ Enhancements - Added support for ``utcfromtimestamp()``, ``fromtimestamp()``, and ``combine()`` on `Timestamp` class (:issue:`5351`). - Added Google Analytics (`pandas.io.ga`) basic documentation (:issue:`8835`). See :ref:`here`. - Added flag ``order_categoricals`` to ``StataReader`` and ``read_stata`` to select whether to order imported categorical data (:issue:`8836`). See :ref:`here ` for more information on importing categorical variables from Stata data files. +- Added support for passing a column name as the size argument to ``DataFrame.plot(kind='scatter')``, along with a ``size_range`` argument to control scaling (:issue:`8244`). .. _whatsnew_0152.performance: diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 74ec6d22ca4cd..0cf886aa58b76 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -9,7 +9,8 @@ from datetime import datetime, date -from pandas import Series, DataFrame, MultiIndex, PeriodIndex, date_range +from pandas import (Series, DataFrame, MultiIndex, PeriodIndex, date_range, + Categorical) from pandas.compat import (range, lrange, StringIO, lmap, lzip, u, zip, iteritems, OrderedDict) from pandas.util.decorators import cache_readonly @@ -1645,6 +1646,32 @@ def test_plot_scatter_with_c(self): self.assertIs(ax.collections[0].colorbar, None) self._check_colors(ax.collections, facecolors=['r']) + @slow + def test_plot_scatter_with_size(self): + df = DataFrame(randn(6, 3), + index=list(string.ascii_letters[:6]), + columns=['x', 'y', 'z']) + df['group'] = Categorical(random.randint(1, 4, 6)) + + size_range = (100, 500) + ax1 = df.plot(kind='scatter', x='x', y='y', s='z', + size_range=size_range) + point_sizes1 = ax1.collections[0]._sizes + self.assertGreaterEqual(min(point_sizes1), size_range[0]) + self.assertLessEqual(max(point_sizes1), size_range[1]) + + # Categorical size column + ax2 = df.plot(kind='scatter', x='x', y='y', s='group', + size_range=size_range) + point_sizes2 = ax2.collections[0]._sizes + self.assertGreaterEqual(min(point_sizes2), size_range[0]) + self.assertLessEqual(max(point_sizes2), size_range[1]) + unique_sizes = np.unique(point_sizes2) + self.assertEqual( + len(unique_sizes), + len(df['group'].cat.categories) + ) + @slow def test_plot_bar(self): df = DataFrame(randn(6, 4), diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index cb669b75e5c96..75d8af651dcb8 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -1374,7 +1374,8 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True): class ScatterPlot(MPLPlot): _layout_type = 'single' - def __init__(self, data, x, y, c=None, **kwargs): + def __init__(self, data, x, y, c=None, s=None, + size_range=(50, 1000), **kwargs): MPLPlot.__init__(self, data, **kwargs) if x is None or y is None: raise ValueError( 'scatter requires and x and y column') @@ -1387,6 +1388,16 @@ def __init__(self, data, x, y, c=None, **kwargs): self.x = x self.y = y self.c = c + self.size_range = size_range + + # Set up size scaling if necessary, need to do this before plot + # generation starts and non-numeric data thrown away + if s is None: + self.s_values = self.plt.rcParams['lines.markersize'] + elif isinstance(s, str) and s in self.data.columns: + self.s_values = self._convert_column_to_size(s) + else: + self.s_values = s @property def nseries(self): @@ -1415,12 +1426,14 @@ def _make_plot(self): else: c_values = c + if self.legend and hasattr(self, 'label'): label = self.label else: label = None scatter = ax.scatter(data[x].values, data[y].values, c=c_values, - label=label, cmap=cmap, **self.kwds) + s=self.s_values, label=label, cmap=cmap, + **self.kwds) if cb: img = ax.collections[0] kws = dict(ax=ax) @@ -1437,6 +1450,23 @@ def _make_plot(self): err_kwds['ecolor'] = scatter.get_facecolor()[0] ax.errorbar(data[x].values, data[y].values, linestyle='none', **err_kwds) + def _convert_column_to_size(self, col_name): + min_size, max_size = self.size_range + size_col = self.data[col_name] + + if com.is_categorical_dtype(size_col): + n_categories = len(size_col.cat.categories) + cat_sizes = np.linspace(min_size, max_size, num=n_categories) + size_mapper = Series(cat_sizes, index=size_col.cat.categories) + point_sizes = size_col.map(size_mapper) + else: + vals = self.data[col_name].values + val_range = vals.max() - vals.min() + normalized_vals = (vals - vals.min()) / val_range + point_sizes = (min_size + (normalized_vals * (max_size - min_size))) + + return point_sizes + def _post_plot_logic(self): ax = self.axes[0] x, y = self.x, self.y