Skip to content

ENH/VIS: Pass DataFrame column to size argument in DataFrame.scatter #8885

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions doc/source/visualization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
from numpy.random import randn, rand, randint
np.random.seed(123456)
from pandas import DataFrame, Series, date_range, options
from pandas import DataFrame, Series, date_range, options, Categorical
import pandas.util.testing as tm
np.set_printoptions(precision=4, suppress=True)
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -587,20 +587,43 @@ each point:

plt.close('all')

You can pass other keywords supported by matplotlib ``scatter``.
Below example shows a bubble chart using a dataframe column values as bubble size.
You can also pass a column name as the ``s`` (size) argument to have
the point sizes scale according to that column's values. Currently
this is only supported for string column names.

The minimum and
maximum sizes of the bubbles (in points) are controlled by the
``size_range`` argument, with a default range of ``(50, 1000)``. The
below example shows a bubble chart using a dataframe column values
as bubble size.

.. ipython:: python

@savefig scatter_plot_sizes.png
df.plot(kind='scatter', x='a', y='b', s='c');

.. ipython:: python
:suppress:

plt.close('all')

Categorical columns can also be used to set point sizes, producing
a set of equally spaced point sizes:

.. ipython:: python

@savefig scatter_plot_bubble.png
df.plot(kind='scatter', x='a', y='b', s=df['c']*200);
df['group'] = Categorical(randint(1, 4, 50))
@savefig scatter_plot_categorical_sizes.png
df.plot(kind='scatter', x='a', y='b', s='group')

.. ipython:: python
:suppress:

plt.close('all')

See the :meth:`scatter <matplotlib.axes.Axes.scatter>` method and the
You can pass other keywords supported by matplotlib ``scatter``, e.g. ``alpha``
to control the transparency of points. See the
:meth:`scatter <matplotlib.axes.Axes.scatter>` method and the
`matplotlib scatter documenation <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter>`__ for more.

.. _visualization.hexbin:
Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.15.2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Enhancements
- Added support for ``utcfromtimestamp()``, ``fromtimestamp()``, and ``combine()`` on `Timestamp` class (:issue:`5351`).
- Added Google Analytics (`pandas.io.ga`) basic documentation (:issue:`8835`). See :ref:`here<remote_data.ga>`.
- Added flag ``order_categoricals`` to ``StataReader`` and ``read_stata`` to select whether to order imported categorical data (:issue:`8836`). See :ref:`here <io.stata-categorical>` for more information on importing categorical variables from Stata data files.
- Added support for passing a column name as the size argument to ``DataFrame.plot(kind='scatter')``, along with a ``size_range`` argument to control scaling (:issue:`8244`).

.. _whatsnew_0152.performance:

Expand Down
29 changes: 28 additions & 1 deletion pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from datetime import datetime, date

from pandas import Series, DataFrame, MultiIndex, PeriodIndex, date_range
from pandas import (Series, DataFrame, MultiIndex, PeriodIndex, date_range,
Categorical)
from pandas.compat import (range, lrange, StringIO, lmap, lzip, u, zip,
iteritems, OrderedDict)
from pandas.util.decorators import cache_readonly
Expand Down Expand Up @@ -1645,6 +1646,32 @@ def test_plot_scatter_with_c(self):
self.assertIs(ax.collections[0].colorbar, None)
self._check_colors(ax.collections, facecolors=['r'])

@slow
def test_plot_scatter_with_size(self):
df = DataFrame(randn(6, 3),
index=list(string.ascii_letters[:6]),
columns=['x', 'y', 'z'])
df['group'] = Categorical(random.randint(1, 4, 6))

size_range = (100, 500)
ax1 = df.plot(kind='scatter', x='x', y='y', s='z',
size_range=size_range)
point_sizes1 = ax1.collections[0]._sizes
self.assertGreaterEqual(min(point_sizes1), size_range[0])
self.assertLessEqual(max(point_sizes1), size_range[1])

# Categorical size column
ax2 = df.plot(kind='scatter', x='x', y='y', s='group',
size_range=size_range)
point_sizes2 = ax2.collections[0]._sizes
self.assertGreaterEqual(min(point_sizes2), size_range[0])
self.assertLessEqual(max(point_sizes2), size_range[1])
unique_sizes = np.unique(point_sizes2)
self.assertEqual(
len(unique_sizes),
len(df['group'].cat.categories)
)

@slow
def test_plot_bar(self):
df = DataFrame(randn(6, 4),
Expand Down
34 changes: 32 additions & 2 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,8 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
class ScatterPlot(MPLPlot):
_layout_type = 'single'

def __init__(self, data, x, y, c=None, **kwargs):
def __init__(self, data, x, y, c=None, s=None,
size_range=(50, 1000), **kwargs):
MPLPlot.__init__(self, data, **kwargs)
if x is None or y is None:
raise ValueError( 'scatter requires and x and y column')
Expand All @@ -1387,6 +1388,16 @@ def __init__(self, data, x, y, c=None, **kwargs):
self.x = x
self.y = y
self.c = c
self.size_range = size_range

# Set up size scaling if necessary, need to do this before plot
# generation starts and non-numeric data thrown away
if s is None:
self.s_values = self.plt.rcParams['lines.markersize']
elif isinstance(s, str) and s in self.data.columns:
self.s_values = self._convert_column_to_size(s)
else:
self.s_values = s

@property
def nseries(self):
Expand Down Expand Up @@ -1415,12 +1426,14 @@ def _make_plot(self):
else:
c_values = c


if self.legend and hasattr(self, 'label'):
label = self.label
else:
label = None
scatter = ax.scatter(data[x].values, data[y].values, c=c_values,
label=label, cmap=cmap, **self.kwds)
s=self.s_values, label=label, cmap=cmap,
**self.kwds)
if cb:
img = ax.collections[0]
kws = dict(ax=ax)
Expand All @@ -1437,6 +1450,23 @@ def _make_plot(self):
err_kwds['ecolor'] = scatter.get_facecolor()[0]
ax.errorbar(data[x].values, data[y].values, linestyle='none', **err_kwds)

def _convert_column_to_size(self, col_name):
min_size, max_size = self.size_range
size_col = self.data[col_name]

if com.is_categorical_dtype(size_col):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a comment or 2 to explain what you are doing here

n_categories = len(size_col.cat.categories)
cat_sizes = np.linspace(min_size, max_size, num=n_categories)
size_mapper = Series(cat_sizes, index=size_col.cat.categories)
point_sizes = size_col.map(size_mapper)
else:
vals = self.data[col_name].values
val_range = vals.max() - vals.min()
normalized_vals = (vals - vals.min()) / val_range
point_sizes = (min_size + (normalized_vals * (max_size - min_size)))

return point_sizes

def _post_plot_logic(self):
ax = self.axes[0]
x, y = self.x, self.y
Expand Down