Skip to content

Add cdfplot #5700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def _skip_if_no_scipy():
raise nose.SkipTest("no scipy")


def _skip_if_no_sm():
    """Raise ``nose.SkipTest`` when ``statsmodels.api`` cannot be imported."""
    try:
        # Imported purely to probe availability; the module is not used here.
        import statsmodels.api  # noqa
    except ImportError:
        raise nose.SkipTest("no statsmodels")


@tm.mplskip
class TestSeriesPlots(tm.TestCase):
def setUp(self):
Expand Down Expand Up @@ -60,6 +67,7 @@ def test_plot(self):
_check_plot_works(self.series[:5].plot, kind='barh')
_check_plot_works(self.series[:10].plot, kind='barh')
_check_plot_works(Series(randn(10)).plot, kind='bar', color='black')
_check_plot_works(Series(randn(10)).plot, kind='cdf')

@slow
def test_plot_figsize_and_title(self):
Expand Down Expand Up @@ -453,6 +461,26 @@ def test_plot_xy(self):
# columns.inferred_type == 'mixed'
# TODO add MultiIndex test

def test_get_plot_kind(self):
    """Check the kind-string to plot-class mapping of ``_get_plot_kind``."""
    from pandas.tools.plotting import (LinePlot, BarPlot, DistributionPlot,
                                       ScatterPlot, _get_plot_kind)

    # Kinds that resolve directly to a plot class.
    expected = [('line', LinePlot), ('bar', BarPlot), ('barh', BarPlot),
                ('scatter', ScatterPlot)]
    for kind, kls in expected:
        self.assertEqual(_get_plot_kind(kind), kls)

    # Distribution kinds come back as partials with ``kind`` pre-bound.
    for kind in ('kde', 'cdf'):
        bound = _get_plot_kind(kind)
        self.assertEqual(bound.func, DistributionPlot)
        self.assertEqual(bound.keywords, {'kind': kind})

    # 'scatter' is rejected when a series plot is requested.
    with tm.assertRaises(ValueError):
        _get_plot_kind('scatter', series=True)

    # Unknown kinds are rejected as well.
    with tm.assertRaises(ValueError):
        _get_plot_kind('NOT A PLOT KIND')

@slow
def test_xcompat(self):
import pandas as pd
Expand Down
115 changes: 67 additions & 48 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,25 @@
def _get_standard_kind(kind):
return {'density': 'kde'}.get(kind, kind)


def _get_plot_kind(kind, series=False):
    """
    Look up the plot constructor for a standardized ``kind`` string.

    Parameters
    ----------
    kind : str
        One of 'line', 'bar', 'barh', 'scatter', 'kde' or 'cdf'.
    series : bool, default False
        When True, 'scatter' is rejected.

    Returns
    -------
    A plot class, or for 'kde'/'cdf' a ``functools.partial`` of
    ``DistributionPlot`` with ``kind`` already bound.

    Raises
    ------
    ValueError
        If ``kind`` is unknown, or is 'scatter' while ``series`` is True.
    """
    from functools import partial

    registry = {
        'line': LinePlot,
        'bar': BarPlot,
        'barh': BarPlot,
        'scatter': ScatterPlot,
        'kde': partial(DistributionPlot, kind='kde'),
        'cdf': partial(DistributionPlot, kind='cdf'),
    }

    if kind == 'scatter' and series:
        raise ValueError('Invalid chart type (%s) given for series plot'
                         % kind)

    try:
        return registry[kind]
    except KeyError:
        raise ValueError('Invalid chart type given %s' % kind)


def _get_standard_colors(num_colors=None, colormap=None, color_type='default',
color=None):
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -1174,60 +1193,77 @@ def _get_marked_label(self, label, col_num):
return label


class DistributionPlot(MPLPlot):
    """
    Plot an estimate of the distribution of each data column.

    ``kind='kde'`` evaluates a ``scipy.stats.gaussian_kde``; any other kind
    (intended: ``'cdf'``) plots the ``cdf`` of a fitted
    ``statsmodels`` ``KDEUnivariate``.
    """

    def __init__(self, data, kind, **kwargs):
        """
        Parameters
        ----------
        data : NDFrame
        kind : str
            `kde` or `cdf`
        ind : array-like, optional (passed through ``kwargs``)
            Points at which the curve is drawn. When omitted, a grid is
            derived from the data.
        bw_method : optional (passed through ``kwargs``)
            Forwarded to ``gaussian_kde`` for ``kind='kde'`` when the
            installed scipy is at least 0.11.0.
        """
        MPLPlot.__init__(self, data, **kwargs)
        self.kind = kind
        # Pop the estimator-only options so they are not forwarded to the
        # plotting function together with the remaining keywords.
        self.kde_kwds = {'ind': self.kwds.pop('ind', None),
                         'bw_method': self.kwds.pop('bw_method', None)}
        self.plotf = self._get_plot_function()
        self.colors = self._get_colors()

    def _make_plot(self):
        """Compute the curve for every column and draw it on its axes."""
        is_kde = self.kind == 'kde'
        if is_kde:
            from scipy.stats import gaussian_kde
            from scipy import __version__ as spv
            from distutils.version import LooseVersion
            # The version check does not depend on the column: do it once.
            supports_bw_method = LooseVersion(spv) >= '0.11.0'
        else:
            import statsmodels.api as sm

        user_ind = self.kde_kwds.get('ind')
        bw_method = self.kde_kwds.get('bw_method')

        for i, (label, y) in enumerate(self._iter_data()):
            ax = self._get_ax(i)
            style = self._get_style(i, label)

            label = com.pprint_thing(label)

            # calculation
            if is_kde:
                if supports_bw_method:
                    gkde = gaussian_kde(y, bw_method=bw_method)
                else:
                    gkde = gaussian_kde(y)
                    # Any non-None value counts as "explicitly requested".
                    if bw_method is not None:
                        msg = ('bw_method was added in Scipy 0.11.0.' +
                               ' Scipy version in use is %s.' % spv)
                        warnings.warn(msg)

                if user_ind is None:
                    # Default grid: the sample range padded by half its
                    # width on each side.
                    y_min, y_max = min(y), max(y)
                    sample_range = y_max - y_min
                    ind = np.linspace(y_min - 0.5 * sample_range,
                                      y_max + 0.5 * sample_range, 1000)
                else:
                    ind = user_ind
                y = gkde.evaluate(ind)
                ax.set_ylabel("Density")
            else:
                k = sm.nonparametric.KDEUnivariate(y)
                k.fit()
                if user_ind is None:
                    ind = k.support
                    y = k.cdf
                else:
                    # k.cdf is tabulated on k.support, not on the requested
                    # points; resample it so x and y have matching lengths.
                    ind = user_ind
                    y = np.interp(ind, k.support, k.cdf)

            kwds = self.kwds.copy()
            kwds['label'] = label
            self._maybe_add_color(self.colors, kwds, style, i)
            if style is None:
                args = (ax, ind, y)
            else:
                args = (ax, ind, y, style)

            self.plotf(*args, **kwds)
            ax.grid(self.grid)

    def _post_plot_logic(self):
        """Add a legend to every axes when legends are enabled."""
        if self.legend:
            for ax in self.axes:
                ax.legend(loc='best')


class KdePlot(DistributionPlot):
    """Backward-compatible name for a ``DistributionPlot`` of kind 'kde'."""

    def __init__(self, data, bw_method=None, ind=None, **kwargs):
        # Callers may still pass ``kind``; it is fixed to 'kde' here.
        kwargs.pop('kind', None)
        DistributionPlot.__init__(self, data, kind='kde',
                                  bw_method=bw_method, ind=ind, **kwargs)


class ScatterPlot(MPLPlot):
def __init__(self, data, x, y, **kwargs):
MPLPlot.__init__(self, data, **kwargs)
Expand Down Expand Up @@ -1679,16 +1715,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
ax_or_axes : matplotlib.AxesSubplot or list of them
"""
kind = _get_standard_kind(kind.lower().strip())
if kind == 'line':
klass = LinePlot
elif kind in ('bar', 'barh'):
klass = BarPlot
elif kind == 'kde':
klass = KdePlot
elif kind == 'scatter':
klass = ScatterPlot
else:
raise ValueError('Invalid chart type given %s' % kind)
klass = _get_plot_kind(kind)

if kind == 'scatter':
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
Expand Down Expand Up @@ -1782,15 +1809,7 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None,
See matplotlib documentation online for more on this subject
"""
kind = _get_standard_kind(kind.lower().strip())
if kind == 'line':
klass = LinePlot
elif kind in ('bar', 'barh'):
klass = BarPlot
elif kind == 'kde':
klass = KdePlot
else:
raise ValueError('Invalid chart type given %s' % kind)

klass = _get_plot_kind(kind, series=True)
"""
If no axis is specified, we check whether there are existing figures.
If so, we get the current axis and check whether yaxis ticks are on the
Expand Down