From 40aec1b5699087150c1b44dc56c7f1ab142aaf73 Mon Sep 17 00:00:00 2001 From: David Stephens Date: Sun, 9 Nov 2014 12:30:27 -0800 Subject: [PATCH] BUG: Fix plots showing 2 sets of axis labels when the index is a timeseries. --- doc/source/whatsnew/v0.15.2.txt | 2 + pandas/tests/test_graphics.py | 26 ++++++++++++ pandas/tools/plotting.py | 71 ++++++++++++++++++++++----------- 3 files changed, 76 insertions(+), 23 deletions(-) diff --git a/doc/source/whatsnew/v0.15.2.txt b/doc/source/whatsnew/v0.15.2.txt index 58dc1da214c05..874eb9155e2a3 100644 --- a/doc/source/whatsnew/v0.15.2.txt +++ b/doc/source/whatsnew/v0.15.2.txt @@ -222,3 +222,5 @@ Bug Fixes - Fixed ValueError raised by cummin/cummax when datetime64 Series contains NaT. (:issue:`8965`) +- Bug in plotting if sharex was enabled and index was a timeseries, would show labels on multiple axes (:issue:`3964`). + diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index b06342e8ce3c3..20f4b13867188 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -1329,6 +1329,32 @@ def test_subplots_multiple_axes(self): self._check_axes_shape(axes, axes_num=1, layout=(1, 1)) self.assertEqual(axes.shape, (1, )) + def test_subplots_ts_share_axes(self): + # GH 3964 + fig, axes = self.plt.subplots(3, 3, sharex=True, sharey=True) + self.plt.subplots_adjust(left=0.05, right=0.95, hspace=0.3, wspace=0.3) + df = DataFrame(np.random.randn(10, 9), index=date_range(start='2014-07-01', freq='M', periods=10)) + for i, ax in enumerate(axes.ravel()): + df[i].plot(ax=ax, fontsize=5) + + #Rows other than bottom should not be visible + for ax in axes[0:-1].ravel(): + self._check_visible(ax.get_xticklabels(), visible=False) + + #Bottom row should be visible + for ax in axes[-1].ravel(): + self._check_visible(ax.get_xticklabels(), visible=True) + + #First column should be visible + for ax in axes[[0, 1, 2], [0]].ravel(): + self._check_visible(ax.get_yticklabels(), visible=True) + + #Other columns should not be visible + for ax in axes[[0, 1, 2], [1]].ravel(): + self._check_visible(ax.get_yticklabels(), visible=False) + for ax in axes[[0, 1, 2], [2]].ravel(): + self._check_visible(ax.get_yticklabels(), visible=False) + def test_negative_log(self): df = - DataFrame(rand(6, 4), index=list(string.ascii_letters[:6]), diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index b9a96ee262101..9fa747f28069b 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -1022,7 +1022,10 @@ def _post_plot_logic(self): def _adorn_subplots(self): to_adorn = self.axes - # todo: sharex, sharey handling? + if len(self.axes) > 0: + all_axes = self._get_axes() + nrows, ncols = self._get_axes_layout() + _handle_shared_axes(all_axes, len(all_axes), len(all_axes), nrows, ncols, self.sharex, self.sharey) for ax in to_adorn: if self.yticks is not None: @@ -1375,6 +1378,19 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True): errors[kw] = err return errors + def _get_axes(self): + return self.axes[0].get_figure().get_axes() + + def _get_axes_layout(self): + axes = self._get_axes() + x_set = set() + y_set = set() + for ax in axes: + # check axes coordinates to estimate layout + points = ax.get_position().get_points() + x_set.add(points[0][0]) + y_set.add(points[0][1]) + return (len(y_set), len(x_set)) class ScatterPlot(MPLPlot): _layout_type = 'single' @@ -3231,6 +3247,28 @@ def _subplots(naxes=None, sharex=False, sharey=False, squeeze=True, ax = fig.add_subplot(nrows, ncols, i + 1, **kwds) axarr[i] = ax + _handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey) + + if naxes != nplots: + for ax in axarr[naxes:]: + ax.set_visible(False) + + if squeeze: + # Reshape the array to have the final desired dimension (nrow,ncol), + # though discarding unneeded dimensions that equal 1. If we only have + # one subplot, just return it instead of a 1-element array. + if nplots == 1: + axes = axarr[0] + else: + axes = axarr.reshape(nrows, ncols).squeeze() + else: + # returned axis array will be always 2-d, even if nrows=ncols=1 + axes = axarr.reshape(nrows, ncols) + + return fig, axes + + +def _handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey): if nplots > 1: if sharex and nrows > 1: @@ -3241,8 +3279,11 @@ def _subplots(naxes=None, sharex=False, sharey=False, squeeze=True, # set_visible will not be effective if # minor axis has NullLocator and NullFormattor (default) import matplotlib.ticker as ticker - ax.xaxis.set_minor_locator(ticker.AutoLocator()) - ax.xaxis.set_minor_formatter(ticker.FormatStrFormatter('')) + + if isinstance(ax.xaxis.get_minor_locator(), ticker.NullLocator): + ax.xaxis.set_minor_locator(ticker.AutoLocator()) + if isinstance(ax.xaxis.get_minor_formatter(), ticker.NullFormatter): + ax.xaxis.set_minor_formatter(ticker.FormatStrFormatter('')) for label in ax.get_xticklabels(minor=True): label.set_visible(False) except Exception: # pragma no cover @@ -3255,32 +3296,16 @@ def _subplots(naxes=None, sharex=False, sharey=False, squeeze=True, label.set_visible(False) try: import matplotlib.ticker as ticker - ax.yaxis.set_minor_locator(ticker.AutoLocator()) - ax.yaxis.set_minor_formatter(ticker.FormatStrFormatter('')) + if isinstance(ax.yaxis.get_minor_locator(), ticker.NullLocator): + ax.yaxis.set_minor_locator(ticker.AutoLocator()) + if isinstance(ax.yaxis.get_minor_formatter(), ticker.NullFormatter): + ax.yaxis.set_minor_formatter(ticker.FormatStrFormatter('')) for label in ax.get_yticklabels(minor=True): label.set_visible(False) except Exception: # pragma no cover pass ax.yaxis.get_label().set_visible(False) - if naxes != nplots: - for ax in axarr[naxes:]: - ax.set_visible(False) - - if squeeze: - # Reshape the array to have the final desired dimension (nrow,ncol), - # though discarding unneeded dimensions that equal 1. If we only have - # one subplot, just return it instead of a 1-element array. - if nplots == 1: - axes = axarr[0] - else: - axes = axarr.reshape(nrows, ncols).squeeze() - else: - # returned axis array will be always 2-d, even if nrows=ncols=1 - axes = axarr.reshape(nrows, ncols) - - return fig, axes - def _flatten(axes): if not com.is_list_like(axes):