Skip to content

Commit 187a9f7

Browse files
committed
Merge pull request #613 from shoyer/facet-dataset
Ensure xplt.FacetGrid works on Dataset objects
2 parents aca62ab + 11eecab commit 187a9f7

File tree

5 files changed

+125
-71
lines changed

5 files changed

+125
-71
lines changed

doc/plotting.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,9 @@ attributes, both 2d Numpy object arrays.
404404
405405
g.name_dicts
406406
407-
It's possible to select the :py:class:`xray.DataArray` corresponding to the FacetGrid
408-
through the ``name_dicts``.
407+
It's possible to select the :py:class:`xray.DataArray` or
408+
:py:class:`xray.Dataset` corresponding to the FacetGrid through the
409+
``name_dicts``.
409410

410411
.. ipython:: python
411412
@@ -427,6 +428,8 @@ they have been plotted.
427428
@savefig plot_facet_iterator.png height=12in
428429
plt.show()
429430
431+
TODO: add an example of using the ``map`` method to plot dataset variables
432+
(e.g., with ``plt.quiver``).
430433

431434
Maps
432435
----

xray/plot/facetgrid.py

Lines changed: 70 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import division
22

3+
import inspect
34
import warnings
45
import itertools
56
import functools
@@ -176,9 +177,16 @@ def __init__(self, data, col=None, row=None, col_wrap=None,
176177
self._col_wrap = col_wrap
177178
self._x_var = None
178179
self._y_var = None
180+
self._cmap_extend = None
179181
self._mappables = []
180182

181-
self.set_titles()
183+
@property
184+
def _left_axes(self):
185+
return self.axes[:, 0]
186+
187+
@property
188+
def _bottom_axes(self):
189+
return self.axes[-1, :]
182190

183191
def map_dataarray(self, func, x, y, **kwargs):
184192
"""
@@ -201,46 +209,22 @@ def map_dataarray(self, func, x, y, **kwargs):
201209
self : FacetGrid object
202210
203211
"""
204-
205212
# These should be consistent with xray.plot._plot2d
206213
cmap_kwargs = {'plot_data': self.data.values,
207-
'vmin': None,
208-
'vmax': None,
209-
'cmap': None,
210-
'center': None,
211-
'robust': False,
212-
'extend': None,
213214
# MPL default
214215
'levels': 7 if 'contour' in func.__name__ else None,
215216
'filled': func.__name__ != 'contour',
216217
}
217218

218-
# Allow kwargs to override these defaults
219-
# Remove cmap_kwargs from kwargs for now, we will add them back later
220-
for param in list(kwargs):
221-
if param in cmap_kwargs:
222-
cmap_kwargs[param] = kwargs.pop(param)
219+
cmap_args = inspect.getargspec(_determine_cmap_params).args
220+
cmap_kwargs.update((a, kwargs[a]) for a in cmap_args if a in kwargs)
223221

224-
# colormap inference has to happen here since all the data in
225-
# self.data is required to make the right choice
226222
cmap_params = _determine_cmap_params(**cmap_kwargs)
227223

228-
if 'contour' in func.__name__:
229-
# extend is a keyword argument only for contour and contourf, but
230-
# passing it to the colorbar is sufficient for imshow and
231-
# pcolormesh
232-
kwargs['extend'] = cmap_params['extend']
233-
kwargs['levels'] = cmap_params['levels']
234-
235-
defaults = {
236-
'add_colorbar': False,
237-
'add_labels': False,
238-
'norm': cmap_params.pop('cnorm'),
239-
}
240-
241224
# Order is important
242-
defaults.update(cmap_params)
243-
defaults.update(kwargs)
225+
func_kwargs = kwargs.copy()
226+
func_kwargs.update(cmap_params)
227+
func_kwargs.update({'add_colorbar': False, 'add_labels': False})
244228

245229
# Get x, y labels for the first subplot
246230
x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]],
@@ -250,40 +234,67 @@ def map_dataarray(self, func, x, y, **kwargs):
250234
# None is the sentinel value
251235
if d is not None:
252236
subset = self.data.loc[d]
253-
self._mappables.append(func(subset, x, y, ax=ax, **defaults))
237+
mappable = func(subset, x, y, ax=ax, **func_kwargs)
238+
self._mappables.append(mappable)
254239

255-
# Left side labels
256-
for ax in self.axes[:, 0]:
257-
ax.set_ylabel(y)
240+
self._cmap_extend = cmap_params.get('extend')
241+
self._finalize_grid(x, y)
258242

259-
# Bottom labels
260-
for ax in self.axes[-1, :]:
261-
ax.set_xlabel(x)
243+
if kwargs.get('add_colorbar', True):
244+
self.add_colorbar()
245+
246+
return self
262247

248+
def _finalize_grid(self, *axlabels):
249+
"""Finalize the annotations and layout."""
250+
self.set_axis_labels(*axlabels)
251+
self.set_titles()
263252
self.fig.tight_layout()
264253

265-
if self._single_group:
266-
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
267-
if d is None:
268-
ax.set_visible(False)
254+
for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
255+
if namedict is None:
256+
ax.set_visible(False)
269257

270-
# colorbar
271-
if kwargs.get('add_colorbar', True):
272-
cbar = self.fig.colorbar(self._mappables[-1],
273-
ax=list(self.axes.flat),
274-
extend=cmap_params['extend'])
258+
def add_colorbar(self, **kwargs):
259+
"""Draw a colorbar
260+
"""
261+
kwargs = kwargs.copy()
262+
if self._cmap_extend is not None:
263+
kwargs.setdefault('extend', self._cmap_extend)
264+
if getattr(self.data, 'name', None) is not None:
265+
kwargs.setdefault('label', self.data.name)
266+
self.fig.colorbar(self._mappables[-1], ax=list(self.axes.flat),
267+
**kwargs)
268+
return self
275269

276-
if self.data.name:
277-
cbar.set_label(self.data.name, rotation=90,
278-
verticalalignment='bottom')
270+
def set_axis_labels(self, x_var=None, y_var=None):
271+
"""Set axis labels on the left column and bottom row of the grid."""
272+
if x_var is not None:
273+
self._x_var = x_var
274+
self.set_xlabels(x_var)
275+
if y_var is not None:
276+
self._y_var = y_var
277+
self.set_ylabels(y_var)
278+
return self
279279

280-
self._x_var = x
281-
self._y_var = y
280+
def set_xlabels(self, label=None, **kwargs):
281+
"""Label the x axis on the bottom row of the grid."""
282+
if label is None:
283+
label = self._x_var
284+
for ax in self._bottom_axes:
285+
ax.set_xlabel(label, **kwargs)
286+
return self
282287

288+
def set_ylabels(self, label=None, **kwargs):
289+
"""Label the y axis on the left column of the grid."""
290+
if label is None:
291+
label = self._y_var
292+
for ax in self._left_axes:
293+
ax.set_ylabel(label, **kwargs)
283294
return self
284295

285296
def set_titles(self, template="{coord} = {value}", maxchar=30,
286-
fontsize=_FONTSIZE, **kwargs):
297+
**kwargs):
287298
"""
288299
Draw titles either above each facet or on the grid margins.
289300
@@ -293,8 +304,6 @@ def set_titles(self, template="{coord} = {value}", maxchar=30,
293304
Template for plot titles containing {coord} and {value}
294305
maxchar : int
295306
Truncate titles at maxchar
296-
fontsize : string or int
297-
Passed to matplotlib.text
298307
kwargs : keyword args
299308
additional arguments to matplotlib.text
300309
@@ -303,8 +312,9 @@ def set_titles(self, template="{coord} = {value}", maxchar=30,
303312
self: FacetGrid object
304313
305314
"""
315+
import matplotlib as mpl
306316

307-
kwargs['fontsize'] = fontsize
317+
kwargs["size"] = kwargs.pop("size", mpl.rcParams["axes.labelsize"])
308318

309319
nicetitle = functools.partial(_nicetitle, maxchar=maxchar,
310320
template=template)
@@ -394,6 +404,10 @@ def map(self, func, *args, **kwargs):
394404
data = self.data.loc[namedict]
395405
plt.sca(ax)
396406
innerargs = [data[a].values for a in args]
397-
func(*innerargs, **kwargs)
407+
# TODO: is it possible to verify that an artist is mappable?
408+
mappable = func(*innerargs, **kwargs)
409+
self._mappables.append(mappable)
410+
411+
self._finalize_grid(*args[:2])
398412

399413
return self

xray/plot/plot.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,6 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
343343
allargs = locals().copy()
344344
allargs.update(allargs.pop('kwargs'))
345345

346-
# Allows use of better FacetGrid defaults
347-
assert allargs.pop('add_labels')
348-
assert allargs.pop('add_colorbar')
349-
350346
# Need the decorated plotting function
351347
allargs['plotfunc'] = globals()[plotfunc.__name__]
352348

@@ -408,7 +404,7 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
408404
kwargs['levels'] = cmap_params['levels']
409405

410406
# This allows the user to pass in a custom norm coming via kwargs
411-
kwargs.setdefault('norm', cmap_params['cnorm'])
407+
kwargs.setdefault('norm', cmap_params['norm'])
412408

413409
ax, primitive = plotfunc(xval, yval, zval, ax=ax,
414410
cmap=cmap_params['cmap'],

xray/plot/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
176176
cmap, cnorm = _build_discrete_cmap(cmap, levels, extend, filled)
177177

178178
return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend,
179-
levels=levels, cnorm=cnorm)
179+
levels=levels, norm=cnorm)
180180

181181

182182
def _infer_xy_labels(darray, x, y):

xray/test/test_plot.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ def text_in_fig():
3232
return set(alltxt)
3333

3434

35+
def find_possible_colorbars():
36+
# nb. this function also matches meshes from pcolormesh
37+
return plt.gcf().findobj(mpl.collections.QuadMesh)
38+
39+
3540
def substring_in_axes(substring, ax):
3641
'''
3742
Return True if a substring is found anywhere in an axes
@@ -266,15 +271,15 @@ def test_robust(self):
266271
self.assertEqual(cmap_params['cmap'].name, 'viridis')
267272
self.assertEqual(cmap_params['extend'], 'both')
268273
self.assertIsNone(cmap_params['levels'])
269-
self.assertIsNone(cmap_params['cnorm'])
274+
self.assertIsNone(cmap_params['norm'])
270275

271276
def test_center(self):
272277
cmap_params = _determine_cmap_params(self.data, center=0.5)
273278
self.assertEqual(cmap_params['vmax'] - 0.5, 0.5 - cmap_params['vmin'])
274279
self.assertEqual(cmap_params['cmap'], 'RdBu_r')
275280
self.assertEqual(cmap_params['extend'], 'neither')
276281
self.assertIsNone(cmap_params['levels'])
277-
self.assertIsNone(cmap_params['cnorm'])
282+
self.assertIsNone(cmap_params['norm'])
278283

279284
def test_integer_levels(self):
280285
data = self.data + 1
@@ -285,7 +290,7 @@ def test_integer_levels(self):
285290
self.assertEqual(cmap_params['cmap'].name, 'Blues')
286291
self.assertEqual(cmap_params['extend'], 'neither')
287292
self.assertEqual(cmap_params['cmap'].N, 5)
288-
self.assertEqual(cmap_params['cnorm'].N, 6)
293+
self.assertEqual(cmap_params['norm'].N, 6)
289294

290295
cmap_params = _determine_cmap_params(data, levels=5,
291296
vmin=0.5, vmax=1.5)
@@ -302,7 +307,7 @@ def test_list_levels(self):
302307
self.assertEqual(cmap_params['vmin'], 0)
303308
self.assertEqual(cmap_params['vmax'], 5)
304309
self.assertEqual(cmap_params['cmap'].N, 5)
305-
self.assertEqual(cmap_params['cnorm'].N, 6)
310+
self.assertEqual(cmap_params['norm'].N, 6)
306311

307312
for wrap_levels in [list, np.array, pd.Index, DataArray]:
308313
cmap_params = _determine_cmap_params(
@@ -790,9 +795,7 @@ def test_colorbar(self):
790795
clim = np.array(image.get_clim())
791796
self.assertTrue(np.allclose(expected, clim))
792797

793-
# There's only one colorbar
794-
cbar = plt.gcf().findobj(mpl.collections.QuadMesh)
795-
self.assertEqual(1, len(cbar))
798+
self.assertEqual(1, len(find_possible_colorbars()))
796799

797800
def test_empty_cell(self):
798801
g = xplt.FacetGrid(self.darray, col='z', col_wrap=2)
@@ -885,6 +888,44 @@ def test_map(self):
885888
self.g.map(plt.contourf, 'x', 'y', Ellipsis)
886889
self.g.map(lambda: None)
887890

891+
def test_map_dataset(self):
892+
g = xplt.FacetGrid(self.darray.to_dataset(name='foo'), col='z')
893+
g.map(plt.contourf, 'x', 'y', 'foo')
894+
895+
alltxt = text_in_fig()
896+
for label in ['x', 'y']:
897+
self.assertIn(label, alltxt)
898+
# everything has a label
899+
self.assertNotIn('None', alltxt)
900+
901+
# colorbar can't be inferred automatically
902+
self.assertNotIn('foo', alltxt)
903+
self.assertEqual(0, len(find_possible_colorbars()))
904+
905+
g.add_colorbar(label='colors!')
906+
self.assertIn('colors!', text_in_fig())
907+
self.assertEqual(1, len(find_possible_colorbars()))
908+
909+
def test_set_axis_labels(self):
910+
g = self.g.map_dataarray(xplt.contourf, 'x', 'y')
911+
g.set_axis_labels('longitude', 'latitude')
912+
alltxt = text_in_fig()
913+
for label in ['longitude', 'latitude']:
914+
self.assertIn(label, alltxt)
915+
916+
def test_facetgrid_colorbar(self):
917+
a = easy_array((10, 15, 4))
918+
d = DataArray(a, dims=['y', 'x', 'z'], name='foo')
919+
920+
d.plot.imshow(x='x', y='y', col='z')
921+
self.assertEqual(1, len(find_possible_colorbars()))
922+
923+
d.plot.imshow(x='x', y='y', col='z', add_colorbar=True)
924+
self.assertEqual(1, len(find_possible_colorbars()))
925+
926+
d.plot.imshow(x='x', y='y', col='z', add_colorbar=False)
927+
self.assertEqual(0, len(find_possible_colorbars()))
928+
888929

889930
class TestFacetGrid4d(PlotTestCase):
890931

0 commit comments

Comments
 (0)