Skip to content

Commit 7f6eaec

Browse files
committed
Simplify FacetGrid.map_dataarary
1 parent 6d819c9 commit 7f6eaec

File tree

4 files changed

+10
-29
lines changed

4 files changed

+10
-29
lines changed

xray/plot/facetgrid.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import division
22

3+
import inspect
34
import warnings
45
import itertools
56
import functools
@@ -208,46 +209,26 @@ def map_dataarray(self, func, x, y, **kwargs):
208209
self : FacetGrid object
209210
210211
"""
211-
212212
# These should be consistent with xray.plot._plot2d
213213
cmap_kwargs = {'plot_data': self.data.values,
214-
'vmin': None,
215-
'vmax': None,
216-
'cmap': None,
217-
'center': None,
218-
'robust': False,
219-
'extend': None,
220214
# MPL default
221215
'levels': 7 if 'contour' in func.__name__ else None,
222216
'filled': func.__name__ != 'contour',
223217
}
224218

225-
# Allow kwargs to override these defaults
226-
# Remove cmap_kwargs from kwargs for now, we will add them back later
227-
for param in list(kwargs):
228-
if param in cmap_kwargs:
229-
cmap_kwargs[param] = kwargs.pop(param)
219+
cmap_args = inspect.getargspec(_determine_cmap_params).args
220+
cmap_kwargs.update((a, kwargs[a]) for a in cmap_args if a in kwargs)
230221

231-
# colormap inference has to happen here since all the data in
232-
# self.data is required to make the right choice
233222
cmap_params = _determine_cmap_params(**cmap_kwargs)
234223

235-
if 'contour' in func.__name__:
236-
# extend is a keyword argument only for contour and contourf, but
237-
# passing it to the colorbar is sufficient for imshow and
238-
# pcolormesh
239-
kwargs['extend'] = cmap_params['extend']
240-
kwargs['levels'] = cmap_params['levels']
241-
242224
defaults = {
243225
'add_colorbar': False,
244226
'add_labels': False,
245-
'norm': cmap_params.pop('cnorm'),
246227
}
247228

248229
# Order is important
249-
defaults.update(cmap_params)
250230
defaults.update(kwargs)
231+
defaults.update(cmap_params)
251232

252233
# Get x, y labels for the first subplot
253234
x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]],

xray/plot/plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def newplotfunc(darray, x=None, y=None, ax=None, row=None, col=None,
408408
kwargs['levels'] = cmap_params['levels']
409409

410410
# This allows the user to pass in a custom norm coming via kwargs
411-
kwargs.setdefault('norm', cmap_params['cnorm'])
411+
kwargs.setdefault('norm', cmap_params['norm'])
412412

413413
ax, primitive = plotfunc(xval, yval, zval, ax=ax,
414414
cmap=cmap_params['cmap'],

xray/plot/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
176176
cmap, cnorm = _build_discrete_cmap(cmap, levels, extend, filled)
177177

178178
return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend,
179-
levels=levels, cnorm=cnorm)
179+
levels=levels, norm=cnorm)
180180

181181

182182
def _infer_xy_labels(darray, x, y):

xray/test/test_plot.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,15 +270,15 @@ def test_robust(self):
270270
self.assertEqual(cmap_params['cmap'].name, 'viridis')
271271
self.assertEqual(cmap_params['extend'], 'both')
272272
self.assertIsNone(cmap_params['levels'])
273-
self.assertIsNone(cmap_params['cnorm'])
273+
self.assertIsNone(cmap_params['norm'])
274274

275275
def test_center(self):
276276
cmap_params = _determine_cmap_params(self.data, center=0.5)
277277
self.assertEqual(cmap_params['vmax'] - 0.5, 0.5 - cmap_params['vmin'])
278278
self.assertEqual(cmap_params['cmap'], 'RdBu_r')
279279
self.assertEqual(cmap_params['extend'], 'neither')
280280
self.assertIsNone(cmap_params['levels'])
281-
self.assertIsNone(cmap_params['cnorm'])
281+
self.assertIsNone(cmap_params['norm'])
282282

283283
def test_integer_levels(self):
284284
data = self.data + 1
@@ -289,7 +289,7 @@ def test_integer_levels(self):
289289
self.assertEqual(cmap_params['cmap'].name, 'Blues')
290290
self.assertEqual(cmap_params['extend'], 'neither')
291291
self.assertEqual(cmap_params['cmap'].N, 5)
292-
self.assertEqual(cmap_params['cnorm'].N, 6)
292+
self.assertEqual(cmap_params['norm'].N, 6)
293293

294294
cmap_params = _determine_cmap_params(data, levels=5,
295295
vmin=0.5, vmax=1.5)
@@ -306,7 +306,7 @@ def test_list_levels(self):
306306
self.assertEqual(cmap_params['vmin'], 0)
307307
self.assertEqual(cmap_params['vmax'], 5)
308308
self.assertEqual(cmap_params['cmap'].N, 5)
309-
self.assertEqual(cmap_params['cnorm'].N, 6)
309+
self.assertEqual(cmap_params['norm'].N, 6)
310310

311311
for wrap_levels in [list, np.array, pd.Index, DataArray]:
312312
cmap_params = _determine_cmap_params(

0 commit comments

Comments
 (0)