diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index de35894b017be..1f799c23c5396 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -774,7 +774,12 @@ class MPLPlot(object): data : """ - _kind = 'base' + + @property + def _kind(self): + """Specify kind str. Must be overridden in child class""" + raise NotImplementedError + _layout_type = 'vertical' _default_rot = 0 orientation = None @@ -938,7 +943,10 @@ def generate(self): self._make_plot() self._add_table() self._make_legend() - self._post_plot_logic() + + for ax in self.axes: + self._post_plot_logic_common(ax, self.data) + self._post_plot_logic(ax, self.data) self._adorn_subplots() def _args_adjust(self): @@ -1055,12 +1063,34 @@ def _add_table(self): ax = self._get_ax(0) table(ax, data) - def _post_plot_logic(self): + def _post_plot_logic_common(self, ax, data): + """Common post process for each axes""" + labels = [com.pprint_thing(key) for key in data.index] + labels = dict(zip(range(len(data.index)), labels)) + + if self.orientation == 'vertical' or self.orientation is None: + if self._need_to_set_index: + xticklabels = [labels.get(x, '') for x in ax.get_xticks()] + ax.set_xticklabels(xticklabels) + self._apply_axis_properties(ax.xaxis, rot=self.rot, + fontsize=self.fontsize) + self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize) + elif self.orientation == 'horizontal': + if self._need_to_set_index: + yticklabels = [labels.get(y, '') for y in ax.get_yticks()] + ax.set_yticklabels(yticklabels) + self._apply_axis_properties(ax.yaxis, rot=self.rot, + fontsize=self.fontsize) + self._apply_axis_properties(ax.xaxis, fontsize=self.fontsize) + else: # pragma no cover + raise ValueError + + def _post_plot_logic(self, ax, data): + """Post process for each axes. Overridden in child classes""" pass def _adorn_subplots(self): - to_adorn = self.axes - + """Common post process unrelated to data""" if len(self.axes) > 0: all_axes = self._get_axes() nrows, ncols = self._get_axes_layout() @@ -1069,7 +1099,7 @@ def _adorn_subplots(self): ncols=ncols, sharex=self.sharex, sharey=self.sharey) - for ax in to_adorn: + for ax in self.axes: if self.yticks is not None: ax.set_yticks(self.yticks) @@ -1090,25 +1120,6 @@ def _adorn_subplots(self): else: self.axes[0].set_title(self.title) - labels = [com.pprint_thing(key) for key in self.data.index] - labels = dict(zip(range(len(self.data.index)), labels)) - - for ax in self.axes: - if self.orientation == 'vertical' or self.orientation is None: - if self._need_to_set_index: - xticklabels = [labels.get(x, '') for x in ax.get_xticks()] - ax.set_xticklabels(xticklabels) - self._apply_axis_properties(ax.xaxis, rot=self.rot, - fontsize=self.fontsize) - self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize) - elif self.orientation == 'horizontal': - if self._need_to_set_index: - yticklabels = [labels.get(y, '') for y in ax.get_yticks()] - ax.set_yticklabels(yticklabels) - self._apply_axis_properties(ax.yaxis, rot=self.rot, - fontsize=self.fontsize) - self._apply_axis_properties(ax.xaxis, fontsize=self.fontsize) - def _apply_axis_properties(self, axis, rot=None, fontsize=None): labels = axis.get_majorticklabels() + axis.get_minorticklabels() for label in labels: @@ -1419,34 +1430,48 @@ def _get_axes_layout(self): y_set.add(points[0][1]) return (len(y_set), len(x_set)) -class ScatterPlot(MPLPlot): - _kind = 'scatter' + +class PlanePlot(MPLPlot): + """ + Abstract class for plotting on plane, currently scatter and hexbin. + """ + _layout_type = 'single' - def __init__(self, data, x, y, c=None, **kwargs): + def __init__(self, data, x, y, **kwargs): MPLPlot.__init__(self, data, **kwargs) if x is None or y is None: - raise ValueError( 'scatter requires and x and y column') + raise ValueError(self._kind + ' requires and x and y column') if com.is_integer(x) and not self.data.columns.holds_integer(): x = self.data.columns[x] if com.is_integer(y) and not self.data.columns.holds_integer(): y = self.data.columns[y] - if com.is_integer(c) and not self.data.columns.holds_integer(): - c = self.data.columns[c] self.x = x self.y = y - self.c = c @property def nseries(self): return 1 + def _post_plot_logic(self, ax, data): + x, y = self.x, self.y + ax.set_ylabel(com.pprint_thing(y)) + ax.set_xlabel(com.pprint_thing(x)) + + +class ScatterPlot(PlanePlot): + _kind = 'scatter' + + def __init__(self, data, x, y, c=None, **kwargs): + super(ScatterPlot, self).__init__(data, x, y, **kwargs) + if com.is_integer(c) and not self.data.columns.holds_integer(): + c = self.data.columns[c] + self.c = c + def _make_plot(self): import matplotlib as mpl mpl_ge_1_3_1 = str(mpl.__version__) >= LooseVersion('1.3.1') - import matplotlib.pyplot as plt - x, y, c, data = self.x, self.y, self.c, self.data ax = self.axes[0] @@ -1457,7 +1482,7 @@ def _make_plot(self): # pandas uses colormap, matplotlib uses cmap. cmap = self.colormap or 'Greys' - cmap = plt.cm.get_cmap(cmap) + cmap = self.plt.cm.get_cmap(cmap) if c is None: c_values = self.plt.rcParams['patch.facecolor'] @@ -1491,46 +1516,22 @@ def _make_plot(self): err_kwds['ecolor'] = scatter.get_facecolor()[0] ax.errorbar(data[x].values, data[y].values, linestyle='none', **err_kwds) - def _post_plot_logic(self): - ax = self.axes[0] - x, y = self.x, self.y - ax.set_ylabel(com.pprint_thing(y)) - ax.set_xlabel(com.pprint_thing(x)) - -class HexBinPlot(MPLPlot): +class HexBinPlot(PlanePlot): _kind = 'hexbin' - _layout_type = 'single' def __init__(self, data, x, y, C=None, **kwargs): - MPLPlot.__init__(self, data, **kwargs) - - if x is None or y is None: - raise ValueError('hexbin requires and x and y column') - if com.is_integer(x) and not self.data.columns.holds_integer(): - x = self.data.columns[x] - if com.is_integer(y) and not self.data.columns.holds_integer(): - y = self.data.columns[y] - + super(HexBinPlot, self).__init__(data, x, y, **kwargs) if com.is_integer(C) and not self.data.columns.holds_integer(): C = self.data.columns[C] - - self.x = x - self.y = y self.C = C - @property - def nseries(self): - return 1 - def _make_plot(self): - import matplotlib.pyplot as plt - x, y, data, C = self.x, self.y, self.data, self.C ax = self.axes[0] # pandas uses colormap, matplotlib uses cmap. cmap = self.colormap or 'BuGn' - cmap = plt.cm.get_cmap(cmap) + cmap = self.plt.cm.get_cmap(cmap) cb = self.kwds.pop('colorbar', True) if C is None: @@ -1547,12 +1548,6 @@ def _make_plot(self): def _make_legend(self): pass - def _post_plot_logic(self): - ax = self.axes[0] - x, y = self.x, self.y - ax.set_ylabel(com.pprint_thing(y)) - ax.set_xlabel(com.pprint_thing(x)) - class LinePlot(MPLPlot): _kind = 'line' @@ -1685,26 +1680,23 @@ def _update_stacker(cls, ax, stacking_id, values): elif (values <= 0).all(): ax._stacker_neg_prior[stacking_id] += values - def _post_plot_logic(self): - df = self.data - + def _post_plot_logic(self, ax, data): condition = (not self._use_dynamic_x() - and df.index.is_all_dates + and data.index.is_all_dates and not self.subplots or (self.subplots and self.sharex)) index_name = self._get_index_name() - for ax in self.axes: - if condition: - # irregular TS rotated 30 deg. by default - # probably a better place to check / set this. - if not self._rot_set: - self.rot = 30 - format_date_labels(ax, rot=self.rot) + if condition: + # irregular TS rotated 30 deg. by default + # probably a better place to check / set this. + if not self._rot_set: + self.rot = 30 + format_date_labels(ax, rot=self.rot) - if index_name is not None and self.use_index: - ax.set_xlabel(index_name) + if index_name is not None and self.use_index: + ax.set_xlabel(index_name) class AreaPlot(LinePlot): @@ -1758,16 +1750,14 @@ def _add_legend_handle(self, handle, label, index=None): handle = Rectangle((0, 0), 1, 1, fc=handle.get_color(), alpha=alpha) LinePlot._add_legend_handle(self, handle, label, index=index) - def _post_plot_logic(self): - LinePlot._post_plot_logic(self) + def _post_plot_logic(self, ax, data): + LinePlot._post_plot_logic(self, ax, data) if self.ylim is None: - if (self.data >= 0).all().all(): - for ax in self.axes: - ax.set_ylim(0, None) - elif (self.data <= 0).all().all(): - for ax in self.axes: - ax.set_ylim(None, 0) + if (data >= 0).all().all(): + ax.set_ylim(0, None) + elif (data <= 0).all().all(): + ax.set_ylim(None, 0) class BarPlot(MPLPlot): @@ -1865,19 +1855,17 @@ def _make_plot(self): start=start, label=label, log=self.log, **kwds) self._add_legend_handle(rect, label, index=i) - def _post_plot_logic(self): - for ax in self.axes: - if self.use_index: - str_index = [com.pprint_thing(key) for key in self.data.index] - else: - str_index = [com.pprint_thing(key) for key in - range(self.data.shape[0])] - name = self._get_index_name() + def _post_plot_logic(self, ax, data): + if self.use_index: + str_index = [com.pprint_thing(key) for key in data.index] + else: + str_index = [com.pprint_thing(key) for key in range(data.shape[0])] + name = self._get_index_name() - s_edge = self.ax_pos[0] - 0.25 + self.lim_offset - e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset + s_edge = self.ax_pos[0] - 0.25 + self.lim_offset + e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset - self._decorate_ticks(ax, name, str_index, s_edge, e_edge) + self._decorate_ticks(ax, name, str_index, s_edge, e_edge) def _decorate_ticks(self, ax, name, ticklabels, start_edge, end_edge): ax.set_xlim((start_edge, end_edge)) @@ -1975,13 +1963,11 @@ def _make_plot_keywords(self, kwds, y): kwds['bins'] = self.bins return kwds - def _post_plot_logic(self): + def _post_plot_logic(self, ax, data): if self.orientation == 'horizontal': - for ax in self.axes: - ax.set_xlabel('Frequency') + ax.set_xlabel('Frequency') else: - for ax in self.axes: - ax.set_ylabel('Frequency') + ax.set_ylabel('Frequency') @property def orientation(self): @@ -2038,9 +2024,8 @@ def _make_plot_keywords(self, kwds, y): kwds['ind'] = self._get_ind(y) return kwds - def _post_plot_logic(self): - for ax in self.axes: - ax.set_ylabel('Density') + def _post_plot_logic(self, ax, data): + ax.set_ylabel('Density') class PiePlot(MPLPlot): @@ -2242,7 +2227,7 @@ def _set_ticklabels(self, ax, labels): def _make_legend(self): pass - def _post_plot_logic(self): + def _post_plot_logic(self, ax, data): pass @property