Skip to content

Commit 5472fb5

Browse files
author
Joe Hamman
authored
pass dask compute/persist args through from load/compute/persist (#1543)
* pass dask compute/persist args through from load/compute/persist * fix test and whatsnew note * test dask compute args with mock * use as_compatible_data instead of np.asarray * requires dask * update setup.py * cleanup imports
1 parent 9a8e2c5 commit 5472fb5

11 files changed

+158
-21
lines changed

ci/requirements-py27-cdat+pynio.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies:
1616
- pathlib2
1717
- pynio
1818
- pytest
19+
- mock
1920
- scipy
2021
- seaborn
2122
- toolz

ci/requirements-py27-min.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ name: test_env
22
dependencies:
33
- python=2.7
44
- pytest
5+
- mock
56
- numpy==1.11
67
- pandas==0.18.0
78
- pip:

ci/requirements-py27-windows.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies:
1111
- netcdf4
1212
- pathlib2
1313
- pytest
14+
- mock
1415
- numpy
1516
- pandas
1617
- scipy

doc/installing.rst

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ pandas) installed first. Then, install xarray with pip::
7373

7474
$ pip install xarray
7575

76-
To run the test suite after installing xarray, install
77-
`py.test <https://pytest.org>`__ (``pip install pytest``) and run
76+
Testing
77+
-------
78+
79+
To run the test suite after installing xarray, first install (via PyPI or conda):
80+
- `py.test <https://pytest.org>`__: Simple unit testing library
81+
- `mock <https://pypi.python.org/pypi/mock>`__: additional testing library required for Python 2
82+
83+
and run
7884
``py.test --pyargs xarray``.

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ Enhancements
9797
other means (:issue:`1459`).
9898
By `Ryan May <https://github.com/dopplershift>`_.
9999

100+
- Support passing keyword arguments to ``load``, ``compute``, and ``persist``
101+
methods. Any keyword arguments supplied to these methods are passed on to
102+
the corresponding dask function (:issue:`1523`).
103+
By `Joe Hamman <https://github.com/jhamman>`_.
100104
- Encoding attributes are now preserved when xarray objects are concatenated.
101105
The encoding is copied from the first object (:issue:`1297`).
102106
By `Joe Hamman <https://github.com/jhamman>`_ and

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838
INSTALL_REQUIRES = ['numpy >= 1.11', 'pandas >= 0.18.0']
3939
TESTS_REQUIRE = ['pytest >= 2.7.1']
40+
if sys.version_info[0] < 3:
41+
TESTS_REQUIRE.append('mock')
4042

4143
DESCRIPTION = "N-D labeled arrays and datasets in Python"
4244
LONG_DESCRIPTION = """

xarray/core/dataarray.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -565,22 +565,31 @@ def reset_coords(self, names=None, drop=False, inplace=False):
565565
dataset[self.name] = self.variable
566566
return dataset
567567

568-
def load(self):
568+
def load(self, **kwargs):
569569
"""Manually trigger loading of this array's data from disk or a
570570
remote source into memory and return this array.
571571
572572
Normally, it should not be necessary to call this method in user code,
573573
because all xarray functions should either work on deferred data or
574574
load data automatically. However, this method can be necessary when
575575
working with many file objects on disk.
576+
577+
Parameters
578+
----------
579+
**kwargs : dict
580+
Additional keyword arguments passed on to ``dask.array.compute``.
581+
582+
See Also
583+
--------
584+
dask.array.compute
576585
"""
577-
ds = self._to_temp_dataset().load()
586+
ds = self._to_temp_dataset().load(**kwargs)
578587
new = self._from_temp_dataset(ds)
579588
self._variable = new._variable
580589
self._coords = new._coords
581590
return self
582591

583-
def compute(self):
592+
def compute(self, **kwargs):
584593
"""Manually trigger loading of this array's data from disk or a
585594
remote source into memory and return a new array. The original is
586595
left unaltered.
@@ -589,18 +598,36 @@ def compute(self):
589598
because all xarray functions should either work on deferred data or
590599
load data automatically. However, this method can be necessary when
591600
working with many file objects on disk.
601+
602+
Parameters
603+
----------
604+
**kwargs : dict
605+
Additional keyword arguments passed on to ``dask.array.compute``.
606+
607+
See Also
608+
--------
609+
dask.array.compute
592610
"""
593611
new = self.copy(deep=False)
594-
return new.load()
612+
return new.load(**kwargs)
595613

596-
def persist(self):
614+
def persist(self, **kwargs):
597615
""" Trigger computation in constituent dask arrays
598616
599617
This keeps them as dask arrays but encourages them to keep data in
600618
memory. This is particularly useful when on a distributed machine.
601619
When on a single machine consider using ``.compute()`` instead.
620+
621+
Parameters
622+
----------
623+
**kwargs : dict
624+
Additional keyword arguments passed on to ``dask.persist``.
625+
626+
See Also
627+
--------
628+
dask.persist
602629
"""
603-
ds = self._to_temp_dataset().persist()
630+
ds = self._to_temp_dataset().persist(**kwargs)
604631
return self._from_temp_dataset(ds)
605632

606633
def copy(self, deep=True):

xarray/core/dataset.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -445,14 +445,23 @@ def sizes(self):
445445
"""
446446
return self.dims
447447

448-
def load(self):
448+
def load(self, **kwargs):
449449
"""Manually trigger loading of this dataset's data from disk or a
450450
remote source into memory and return this dataset.
451451
452452
Normally, it should not be necessary to call this method in user code,
453453
because all xarray functions should either work on deferred data or
454454
load data automatically. However, this method can be necessary when
455455
working with many file objects on disk.
456+
457+
Parameters
458+
----------
459+
**kwargs : dict
460+
Additional keyword arguments passed on to ``dask.array.compute``.
461+
462+
See Also
463+
--------
464+
dask.array.compute
456465
"""
457466
# access .data to coerce everything to numpy or dask arrays
458467
lazy_data = {k: v._data for k, v in self.variables.items()
@@ -461,7 +470,7 @@ def load(self):
461470
import dask.array as da
462471

463472
# evaluate all the dask arrays simultaneously
464-
evaluated_data = da.compute(*lazy_data.values())
473+
evaluated_data = da.compute(*lazy_data.values(), **kwargs)
465474

466475
for k, data in zip(lazy_data, evaluated_data):
467476
self.variables[k].data = data
@@ -473,7 +482,7 @@ def load(self):
473482

474483
return self
475484

476-
def compute(self):
485+
def compute(self, **kwargs):
477486
"""Manually trigger loading of this dataset's data from disk or a
478487
remote source into memory and return a new dataset. The original is
479488
left unaltered.
@@ -482,11 +491,20 @@ def compute(self):
482491
because all xarray functions should either work on deferred data or
483492
load data automatically. However, this method can be necessary when
484493
working with many file objects on disk.
494+
495+
Parameters
496+
----------
497+
**kwargs : dict
498+
Additional keyword arguments passed on to ``dask.array.compute``.
499+
500+
See Also
501+
--------
502+
dask.array.compute
485503
"""
486504
new = self.copy(deep=False)
487-
return new.load()
505+
return new.load(**kwargs)
488506

489-
def _persist_inplace(self):
507+
def _persist_inplace(self, **kwargs):
490508
""" Persist all Dask arrays in memory """
491509
# access .data to coerce everything to numpy or dask arrays
492510
lazy_data = {k: v._data for k, v in self.variables.items()
@@ -495,24 +513,33 @@ def _persist_inplace(self):
495513
import dask
496514

497515
# evaluate all the dask arrays simultaneously
498-
evaluated_data = dask.persist(*lazy_data.values())
516+
evaluated_data = dask.persist(*lazy_data.values(), **kwargs)
499517

500518
for k, data in zip(lazy_data, evaluated_data):
501519
self.variables[k].data = data
502520

503521
return self
504522

505-
def persist(self):
523+
def persist(self, **kwargs):
506524
""" Trigger computation, keeping data as dask arrays
507525
508526
This operation can be used to trigger computation on underlying dask
509527
arrays, similar to ``.compute()``. However this operation keeps the
510528
data as dask arrays. This is particularly useful when using the
511529
dask.distributed scheduler and you want to load a large amount of data
512530
into distributed memory.
531+
532+
Parameters
533+
----------
534+
**kwargs : dict
535+
Additional keyword arguments passed on to ``dask.persist``.
536+
537+
See Also
538+
--------
539+
dask.persist
513540
"""
514541
new = self.copy(deep=False)
515-
return new._persist_inplace()
542+
return new._persist_inplace(**kwargs)
516543

517544
@classmethod
518545
def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,

xarray/core/variable.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,29 +307,49 @@ def data(self, data):
307307
def _indexable_data(self):
308308
return orthogonally_indexable(self._data)
309309

310-
def load(self):
310+
def load(self, **kwargs):
311311
"""Manually trigger loading of this variable's data from disk or a
312312
remote source into memory and return this variable.
313313
314314
Normally, it should not be necessary to call this method in user code,
315315
because all xarray functions should either work on deferred data or
316316
load data automatically.
317+
318+
Parameters
319+
----------
320+
**kwargs : dict
321+
Additional keyword arguments passed on to ``dask.array.compute``.
322+
323+
See Also
324+
--------
325+
dask.array.compute
317326
"""
318-
if not isinstance(self._data, np.ndarray):
327+
if isinstance(self._data, dask_array_type):
328+
self._data = as_compatible_data(self._data.compute(**kwargs))
329+
elif not isinstance(self._data, np.ndarray):
319330
self._data = np.asarray(self._data)
320331
return self
321332

322-
def compute(self):
333+
def compute(self, **kwargs):
323334
"""Manually trigger loading of this variable's data from disk or a
324335
remote source into memory and return a new variable. The original is
325336
left unaltered.
326337
327338
Normally, it should not be necessary to call this method in user code,
328339
because all xarray functions should either work on deferred data or
329340
load data automatically.
341+
342+
Parameters
343+
----------
344+
**kwargs : dict
345+
Additional keyword arguments passed on to ``dask.array.compute``.
346+
347+
See Also
348+
--------
349+
dask.array.compute
330350
"""
331351
new = self.copy(deep=False)
332-
return new.load()
352+
return new.load(**kwargs)
333353

334354
@property
335355
def values(self):

xarray/tests/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020
except ImportError:
2121
import unittest
2222

23+
try:
24+
from unittest import mock
25+
except ImportError:
26+
import mock
27+
2328
try:
2429
import scipy
2530
has_scipy = True

xarray/tests/test_dask.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
import pickle
55
import numpy as np
66
import pandas as pd
7+
import pytest
78

89
import xarray as xr
910
from xarray import Variable, DataArray, Dataset
1011
import xarray.ufuncs as xu
1112
from xarray.core.pycompat import suppress
1213
from . import TestCase, requires_dask
1314

14-
from xarray.tests import unittest
15+
from xarray.tests import unittest, mock
1516

1617
with suppress(ImportError):
1718
import dask
@@ -394,6 +395,47 @@ def test_from_dask_variable(self):
394395
self.assertLazyAndIdentical(self.lazy_array, a)
395396

396397

398+
@requires_dask
399+
@pytest.mark.parametrize("method", ['load', 'compute'])
400+
def test_dask_kwargs_variable(method):
401+
x = Variable('y', da.from_array(np.arange(3), chunks=(2,)))
402+
# args should be passed on to da.Array.compute()
403+
with mock.patch.object(da.Array, 'compute',
404+
return_value=np.arange(3)) as mock_compute:
405+
getattr(x, method)(foo='bar')
406+
mock_compute.assert_called_with(foo='bar')
407+
408+
409+
@requires_dask
410+
@pytest.mark.parametrize("method", ['load', 'compute', 'persist'])
411+
def test_dask_kwargs_dataarray(method):
412+
data = da.from_array(np.arange(3), chunks=(2,))
413+
x = DataArray(data)
414+
if method in ['load', 'compute']:
415+
dask_func = 'dask.array.compute'
416+
else:
417+
dask_func = 'dask.persist'
418+
# args should be passed on to "dask_func"
419+
with mock.patch(dask_func) as mock_func:
420+
getattr(x, method)(foo='bar')
421+
mock_func.assert_called_with(data, foo='bar')
422+
423+
424+
@requires_dask
425+
@pytest.mark.parametrize("method", ['load', 'compute', 'persist'])
426+
def test_dask_kwargs_dataset(method):
427+
data = da.from_array(np.arange(3), chunks=(2,))
428+
x = Dataset({'x': (('y'), data)})
429+
if method in ['load', 'compute']:
430+
dask_func = 'dask.array.compute'
431+
else:
432+
dask_func = 'dask.persist'
433+
# args should be passed on to "dask_func"
434+
with mock.patch(dask_func) as mock_func:
435+
getattr(x, method)(foo='bar')
436+
mock_func.assert_called_with(data, foo='bar')
437+
438+
397439
kernel_call_count = 0
398440
def kernel():
399441
"""Dask kernel to test pickling/unpickling.
@@ -403,6 +445,7 @@ def kernel():
403445
kernel_call_count += 1
404446
return np.ones(1)
405447

448+
406449
def build_dask_array():
407450
global kernel_call_count
408451
kernel_call_count = 0

0 commit comments

Comments
 (0)