Skip to content

Commit 11eecab

Browse files
committed
Adjust add_colorbar option for facets
1 parent 2804247 commit 11eecab

File tree

3 files changed

+27
-24
lines changed

3 files changed

+27
-24
lines changed

xray/plot/facetgrid.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -221,14 +221,10 @@ def map_dataarray(self, func, x, y, **kwargs):
221221

222222
cmap_params = _determine_cmap_params(**cmap_kwargs)
223223

224-
defaults = {
225-
'add_colorbar': False,
226-
'add_labels': False,
227-
}
228-
229224
# Order is important
230-
defaults.update(kwargs)
231-
defaults.update(cmap_params)
225+
func_kwargs = kwargs.copy()
226+
func_kwargs.update(cmap_params)
227+
func_kwargs.update({'add_colorbar': False, 'add_labels': False})
232228

233229
# Get x, y labels for the first subplot
234230
x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]],
@@ -238,13 +234,12 @@ def map_dataarray(self, func, x, y, **kwargs):
238234
# None is the sentinel value
239235
if d is not None:
240236
subset = self.data.loc[d]
241-
mappable = func(subset, x, y, ax=ax, **defaults)
237+
mappable = func(subset, x, y, ax=ax, **func_kwargs)
242238
self._mappables.append(mappable)
243239

244-
self._cmap_extend = defaults.get('extend')
240+
self._cmap_extend = cmap_params.get('extend')
245241
self._finalize_grid(x, y)
246242

247-
# colorbar
248243
if kwargs.get('add_colorbar', True):
249244
self.add_colorbar()
250245

@@ -266,12 +261,10 @@ def add_colorbar(self, **kwargs):
266261
kwargs = kwargs.copy()
267262
if self._cmap_extend is not None:
268263
kwargs.setdefault('extend', self._cmap_extend)
269-
cbar = self.fig.colorbar(self._mappables[-1],
270-
ax=list(self.axes.flat),
271-
**kwargs)
272-
if getattr(self.data, 'name', False):
273-
cbar.set_label(self.data.name, rotation=90,
274-
verticalalignment='bottom')
264+
if getattr(self.data, 'name', None) is not None:
265+
kwargs.setdefault('label', self.data.name)
266+
self.fig.colorbar(self._mappables[-1], ax=list(self.axes.flat),
267+
**kwargs)
275268
return self
276269

277270
def set_axis_labels(self, x_var=None, y_var=None):

xray/plot/plot.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,6 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
343343
allargs = locals().copy()
344344
allargs.update(allargs.pop('kwargs'))
345345

346-
# Allows use of better FacetGrid defaults
347-
assert allargs.pop('add_labels')
348-
assert allargs.pop('add_colorbar')
349-
350346
# Need the decorated plotting function
351347
allargs['plotfunc'] = globals()[plotfunc.__name__]
352348

xray/test/test_plot.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def text_in_fig():
3232
return set(alltxt)
3333

3434

35-
def find_colorbars():
35+
def find_possible_colorbars():
36+
# nb. this function also matches meshes from pcolormesh
3637
return plt.gcf().findobj(mpl.collections.QuadMesh)
3738

3839

@@ -794,7 +795,7 @@ def test_colorbar(self):
794795
clim = np.array(image.get_clim())
795796
self.assertTrue(np.allclose(expected, clim))
796797

797-
self.assertEqual(1, len(find_colorbars()))
798+
self.assertEqual(1, len(find_possible_colorbars()))
798799

799800
def test_empty_cell(self):
800801
g = xplt.FacetGrid(self.darray, col='z', col_wrap=2)
@@ -899,11 +900,11 @@ def test_map_dataset(self):
899900

900901
# colorbar can't be inferred automatically
901902
self.assertNotIn('foo', alltxt)
902-
self.assertEqual(0, len(find_colorbars()))
903+
self.assertEqual(0, len(find_possible_colorbars()))
903904

904905
g.add_colorbar(label='colors!')
905906
self.assertIn('colors!', text_in_fig())
906-
self.assertEqual(1, len(find_colorbars()))
907+
self.assertEqual(1, len(find_possible_colorbars()))
907908

908909
def test_set_axis_labels(self):
909910
g = self.g.map_dataarray(xplt.contourf, 'x', 'y')
@@ -912,6 +913,19 @@ def test_set_axis_labels(self):
912913
for label in ['longitude', 'latitude']:
913914
self.assertIn(label, alltxt)
914915

916+
def test_facetgrid_colorbar(self):
917+
a = easy_array((10, 15, 4))
918+
d = DataArray(a, dims=['y', 'x', 'z'], name='foo')
919+
920+
d.plot.imshow(x='x', y='y', col='z')
921+
self.assertEqual(1, len(find_possible_colorbars()))
922+
923+
d.plot.imshow(x='x', y='y', col='z', add_colorbar=True)
924+
self.assertEqual(1, len(find_possible_colorbars()))
925+
926+
d.plot.imshow(x='x', y='y', col='z', add_colorbar=False)
927+
self.assertEqual(0, len(find_possible_colorbars()))
928+
915929

916930
class TestFacetGrid4d(PlotTestCase):
917931

0 commit comments

Comments
 (0)