diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py
index 1fee318059b7f..85811c7bbcf46 100644
--- a/pandas/tests/test_graphics.py
+++ b/pandas/tests/test_graphics.py
@@ -28,6 +28,13 @@ def _skip_if_no_scipy():
         raise nose.SkipTest("no scipy")
 
 
+def _skip_if_no_sm():
+    try:
+        import statsmodels.api as sm
+    except ImportError:
+        raise nose.SkipTest("no statsmodels")
+
+
 @tm.mplskip
 class TestSeriesPlots(tm.TestCase):
     def setUp(self):
@@ -60,6 +67,14 @@ def test_plot(self):
         _check_plot_works(self.series[:5].plot, kind='barh')
         _check_plot_works(self.series[:10].plot, kind='barh')
         _check_plot_works(Series(randn(10)).plot, kind='bar', color='black')
+
+    @slow
+    def test_plot_cdf(self):
+        # the cdf estimate needs statsmodels; skip cleanly when it is absent
+        _skip_if_no_sm()
+        _check_plot_works(Series(randn(10)).plot, kind='cdf')
+        _check_plot_works(Series(randn(10)).plot, kind='cdf',
+                          ind=[-2, -1, 0, 1, 2])
 
     @slow
     def test_plot_figsize_and_title(self):
@@ -453,6 +468,26 @@ def test_plot_xy(self):
         # columns.inferred_type == 'mixed'
         # TODO add MultiIndex test
 
+    def test_get_plot_kind(self):
+        from pandas.tools.plotting import (LinePlot, BarPlot, DistributionPlot,
+                                           ScatterPlot, _get_plot_kind)
+        kinds = ['line', 'bar', 'barh', 'scatter']
+        klasses = [LinePlot, BarPlot, BarPlot, ScatterPlot]
+        for kind, kls in zip(kinds, klasses):
+            result = _get_plot_kind(kind)
+            self.assertEqual(result, kls)
+
+        for kind in ['kde', 'cdf']:
+            result = _get_plot_kind(kind)
+            self.assertEqual(result.func, DistributionPlot)
+            self.assertEqual(result.keywords, {'kind': kind})
+
+        with tm.assertRaises(ValueError):
+            _get_plot_kind('scatter', series=True)
+
+        with tm.assertRaises(ValueError):
+            _get_plot_kind('NOT A PLOT KIND')
+
     @slow
     def test_xcompat(self):
         import pandas as pd
diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py
index 9984c3fd76f81..32a623d9f9148 100644
--- a/pandas/tools/plotting.py
+++ b/pandas/tools/plotting.py
@@ -94,6 +94,46 @@ def _get_standard_kind(kind):
     return {'density': 'kde'}.get(kind, kind)
 
 
+def _get_plot_kind(kind, series=False):
+    """
+    Look up the plot class used to draw a given kind of chart.
+
+    Parameters
+    ----------
+    kind : str
+        Chart type: 'line', 'bar', 'barh', 'scatter', 'kde' or 'cdf'.
+    series : bool, default False
+        Whether the data being plotted is a Series; 'scatter' is rejected
+        in that case.
+
+    Returns
+    -------
+    klass : plot class or functools.partial
+        'kde' and 'cdf' return ``DistributionPlot`` with ``kind`` pre-bound.
+
+    Raises
+    ------
+    ValueError
+        If `kind` is not a known chart type, or is 'scatter' with
+        ``series=True``.
+    """
+    from functools import partial
+    plot_kinds = {'kde': partial(DistributionPlot, kind='kde'),
+                  'cdf': partial(DistributionPlot, kind='cdf'),
+                  'bar': BarPlot,
+                  'barh': BarPlot,
+                  'line': LinePlot,
+                  'scatter': ScatterPlot}
+
+    if kind == 'scatter' and series:
+        raise ValueError('Invalid chart type (%s) given for series plot'
+                         % kind)
+    try:
+        return plot_kinds[kind]
+    except KeyError:
+        raise ValueError('Invalid chart type given %s' % kind)
+
+
 def _get_standard_colors(num_colors=None, colormap=None, color_type='default',
                          color=None):
     import matplotlib.pyplot as plt
@@ -1174,53 +1214,77 @@ def _get_marked_label(self, label, col_num):
             return label
 
 
-class KdePlot(MPLPlot):
-    def __init__(self, data, bw_method=None, ind=None, **kwargs):
+class DistributionPlot(MPLPlot):
+    def __init__(self, data, kind, **kwargs):
+        """
+        data : NDFrame
+        kind : str
+            `kde` or `cdf`
+        ind : array-like, optional (keyword)
+            x values at which the estimate is drawn
+        bw_method : optional (keyword)
+            passed to scipy's gaussian_kde; only used for `kde`
+        """
         MPLPlot.__init__(self, data, **kwargs)
-        self.bw_method=bw_method
-        self.ind=ind
+        self.kind = kind
+        self.kde_kwds = {'ind': self.kwds.pop('ind', None),
+                         'bw_method': self.kwds.pop('bw_method', None)}
+        self.plotf = self._get_plot_function()
+        self.colors = self._get_colors()
 
     def _make_plot(self):
-        from scipy.stats import gaussian_kde
-        from scipy import __version__ as spv
-        from distutils.version import LooseVersion
-        plotf = self._get_plot_function()
-        colors = self._get_colors()
+        if self.kind == 'kde':
+            from scipy.stats import gaussian_kde
+            from scipy import __version__ as spv
+            from distutils.version import LooseVersion
+        else:
+            import statsmodels.api as sm
+
         for i, (label, y) in enumerate(self._iter_data()):
             ax = self._get_ax(i)
             style = self._get_style(i, label)
 
             label = com.pprint_thing(label)
+            ind = self.kde_kwds.get('ind')
+            bw_method = self.kde_kwds.get('bw_method')
 
-            if LooseVersion(spv) >= '0.11.0':
-                gkde = gaussian_kde(y, bw_method=self.bw_method)
-            else:
-                gkde = gaussian_kde(y)
-                if self.bw_method is not None:
-                    msg = ('bw_method was added in Scipy 0.11.0.' +
-                           ' Scipy version in use is %s.' % spv)
-                    warnings.warn(msg)
-
-            sample_range = max(y) - min(y)
-
-            if self.ind is None:
-                ind = np.linspace(min(y) - 0.5 * sample_range,
-                    max(y) + 0.5 * sample_range, 1000)
+            # calculation
+            if self.kind == 'kde':
+                if LooseVersion(spv) >= '0.11.0':
+                    gkde = gaussian_kde(y, bw_method=bw_method)
+                else:
+                    gkde = gaussian_kde(y)
+                    if bw_method is not None:
+                        msg = ('bw_method was added in Scipy 0.11.0.' +
+                               ' Scipy version in use is %s.' % spv)
+                        warnings.warn(msg)
+
+                sample_range = max(y) - min(y)
+                if ind is None:
+                    ind = np.linspace(min(y) - 0.5 * sample_range,
+                                      max(y) + 0.5 * sample_range, 1000)
+                y = gkde.evaluate(ind)
+                ax.set_ylabel("Density")
             else:
-                ind = self.ind
-
-            ax.set_ylabel("Density")
+                k = sm.nonparametric.KDEUnivariate(y)
+                k.fit()
+                if ind is None:
+                    ind = k.support
+                    y = k.cdf
+                else:
+                    # k.cdf is only defined on k.support; interpolate so a
+                    # user-supplied `ind` gets y values of matching length
+                    y = np.interp(ind, k.support, k.cdf)
 
-            y = gkde.evaluate(ind)
             kwds = self.kwds.copy()
             kwds['label'] = label
-            self._maybe_add_color(colors, kwds, style, i)
+            self._maybe_add_color(self.colors, kwds, style, i)
             if style is None:
                 args = (ax, ind, y)
             else:
                 args = (ax, ind, y, style)
 
-            plotf(*args, **kwds)
+            self.plotf(*args, **kwds)
             ax.grid(self.grid)
 
     def _post_plot_logic(self):
@@ -1228,6 +1292,7 @@ def _post_plot_logic(self):
             for ax in self.axes:
                 ax.legend(loc='best')
 
+
 class ScatterPlot(MPLPlot):
     def __init__(self, data, x, y, **kwargs):
         MPLPlot.__init__(self, data, **kwargs)
@@ -1679,16 +1744,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
     ax_or_axes : matplotlib.AxesSubplot or list of them
     """
     kind = _get_standard_kind(kind.lower().strip())
-    if kind == 'line':
-        klass = LinePlot
-    elif kind in ('bar', 'barh'):
-        klass = BarPlot
-    elif kind == 'kde':
-        klass = KdePlot
-    elif kind == 'scatter':
-        klass = ScatterPlot
-    else:
-        raise ValueError('Invalid chart type given %s' % kind)
+    klass = _get_plot_kind(kind)
 
     if kind == 'scatter':
         plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
@@ -1782,15 +1838,7 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None,
     See matplotlib documentation online for more on this subject
     """
     kind = _get_standard_kind(kind.lower().strip())
-    if kind == 'line':
-        klass = LinePlot
-    elif kind in ('bar', 'barh'):
-        klass = BarPlot
-    elif kind == 'kde':
-        klass = KdePlot
-    else:
-        raise ValueError('Invalid chart type given %s' % kind)
-
+    klass = _get_plot_kind(kind, series=True)
     """
     If no axis is specified, we check whether there are existing figures.
     If so, we get the current axis and check whether yaxis ticks are on the