Skip to content

Commit b190501

Browse files
dopplershift authored and shoyer committed
ENH: Support using opened netCDF4.Dataset (Fixes #1459) (#1508)
Make the filename argument to NetCDF4DataStore polymorphic so that a Dataset can be passed in.
1 parent 174bad0 commit b190501

File tree

4 files changed

+62
-22
lines changed

4 files changed

+62
-22
lines changed

doc/whats-new.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ Enhancements
5555
(:issue:`576`).
5656
By `Stephan Hoyer <https://github.com/shoyer>`_.
5757

58+
- Support using an existing, opened netCDF4 ``Dataset`` with
59+
:py:class:`~xarray.backends.NetCDF4DataStore`. This permits creating an
60+
:py:class:`~xarray.Dataset` from a netCDF4 ``Dataset`` that has been opened using
61+
other means (:issue:`1459`).
62+
By `Ryan May <https://github.com/dopplershift>`_.
63+
5864
Bug fixes
5965
~~~~~~~~~
6066

xarray/backends/api.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,9 @@ def maybe_decode_store(store, lock=False):
278278
engine = _get_default_engine(filename_or_obj,
279279
allow_remote=True)
280280
if engine == 'netcdf4':
281-
store = backends.NetCDF4DataStore(filename_or_obj, group=group,
282-
autoclose=autoclose)
281+
store = backends.NetCDF4DataStore.open(filename_or_obj,
282+
group=group,
283+
autoclose=autoclose)
283284
elif engine == 'scipy':
284285
store = backends.ScipyDataStore(filename_or_obj,
285286
autoclose=autoclose)
@@ -518,7 +519,7 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT,
518519
return combined
519520

520521

521-
WRITEABLE_STORES = {'netcdf4': backends.NetCDF4DataStore,
522+
WRITEABLE_STORES = {'netcdf4': backends.NetCDF4DataStore.open,
522523
'scipy': backends.ScipyDataStore,
523524
'h5netcdf': backends.H5NetCDFStore}
524525

@@ -553,7 +554,7 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
553554
_validate_attrs(dataset)
554555

555556
try:
556-
store_cls = WRITEABLE_STORES[engine]
557+
store_open = WRITEABLE_STORES[engine]
557558
except KeyError:
558559
raise ValueError('unrecognized engine for to_netcdf: %r' % engine)
559560

@@ -564,7 +565,7 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
564565
sync = writer is None
565566

566567
target = path_or_file if path_or_file is not None else BytesIO()
567-
store = store_cls(target, mode, format, group, writer)
568+
store = store_open(target, mode, format, group, writer)
568569

569570
if unlimited_dims is None:
570571
unlimited_dims = dataset.encoding.get('unlimited_dims', None)

xarray/backends/netCDF4_.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -187,35 +187,56 @@ def _open_netcdf4_group(filename, mode, group=None, **kwargs):
187187
with close_on_error(ds):
188188
ds = _nc4_group(ds, group, mode)
189189

190+
_disable_mask_and_scale(ds)
191+
192+
return ds
193+
194+
195+
def _disable_mask_and_scale(ds):
190196
for var in ds.variables.values():
191197
# we handle masking and scaling ourselves
192198
var.set_auto_maskandscale(False)
193-
return ds
194199

195200

196201
class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin):
197202
"""Store for reading and writing data via the Python-NetCDF4 library.
198203
199204
This store supports NetCDF3, NetCDF4 and OpenDAP datasets.
200205
"""
201-
def __init__(self, filename, mode='r', format='NETCDF4', group=None,
202-
writer=None, clobber=True, diskless=False, persist=False,
206+
def __init__(self, netcdf4_dataset, mode='r', writer=None, opener=None,
203207
autoclose=False):
208+
209+
if autoclose and opener is None:
210+
raise ValueError('autoclose requires an opener')
211+
212+
_disable_mask_and_scale(netcdf4_dataset)
213+
214+
self.ds = netcdf4_dataset
215+
self._autoclose = autoclose
216+
self._isopen = True
217+
self.format = self.ds.data_model
218+
self._filename = self.ds.filepath()
219+
self.is_remote = is_remote_uri(self._filename)
220+
self._mode = mode = 'a' if mode == 'w' else mode
221+
if opener:
222+
self._opener = functools.partial(opener, mode=self._mode)
223+
else:
224+
self._opener = opener
225+
super(NetCDF4DataStore, self).__init__(writer)
226+
227+
@classmethod
228+
def open(cls, filename, mode='r', format='NETCDF4', group=None,
229+
writer=None, clobber=True, diskless=False, persist=False,
230+
autoclose=False):
204231
if format is None:
205232
format = 'NETCDF4'
206233
opener = functools.partial(_open_netcdf4_group, filename, mode=mode,
207234
group=group, clobber=clobber,
208235
diskless=diskless, persist=persist,
209236
format=format)
210-
self.ds = opener()
211-
self._autoclose = autoclose
212-
self._isopen = True
213-
self.format = format
214-
self.is_remote = is_remote_uri(filename)
215-
self._filename = filename
216-
self._mode = 'a' if mode == 'w' else mode
217-
self._opener = functools.partial(opener, mode=self._mode)
218-
super(NetCDF4DataStore, self).__init__(writer)
237+
ds = opener()
238+
return cls(ds, mode=mode, writer=writer, opener=opener,
239+
autoclose=autoclose)
219240

220241
def open_store_variable(self, name, var):
221242
with self.ensure_open(autoclose=False):

xarray/tests/test_backends.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,18 @@ def test_0dimensional_variable(self):
762762
expected = Dataset({'x': ((), 123)})
763763
self.assertDatasetIdentical(expected, ds)
764764

765+
def test_already_open_dataset(self):
766+
with create_tmp_file() as tmp_file:
767+
with nc4.Dataset(tmp_file, mode='w') as nc:
768+
v = nc.createVariable('x', 'int')
769+
v[...] = 42
770+
771+
nc = nc4.Dataset(tmp_file, mode='r')
772+
with backends.NetCDF4DataStore(nc, autoclose=False) as store:
773+
with open_dataset(store) as ds:
774+
expected = Dataset({'x': ((), 42)})
775+
self.assertDatasetIdentical(expected, ds)
776+
765777
def test_variable_len_strings(self):
766778
with create_tmp_file() as tmp_file:
767779
values = np.array(['foo', 'bar', 'baz'], dtype=object)
@@ -784,7 +796,7 @@ class NetCDF4DataTest(BaseNetCDF4Test, TestCase):
784796
@contextlib.contextmanager
785797
def create_store(self):
786798
with create_tmp_file() as tmp_file:
787-
with backends.NetCDF4DataStore(tmp_file, mode='w') as store:
799+
with backends.NetCDF4DataStore.open(tmp_file, mode='w') as store:
788800
yield store
789801

790802
@contextlib.contextmanager
@@ -972,8 +984,8 @@ class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, TestCase):
972984
@contextlib.contextmanager
973985
def create_store(self):
974986
with create_tmp_file() as tmp_file:
975-
with backends.NetCDF4DataStore(tmp_file, mode='w',
976-
format='NETCDF3_CLASSIC') as store:
987+
with backends.NetCDF4DataStore.open(
988+
tmp_file, mode='w', format='NETCDF3_CLASSIC') as store:
977989
yield store
978990

979991
@contextlib.contextmanager
@@ -998,8 +1010,8 @@ class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes,
9981010
@contextlib.contextmanager
9991011
def create_store(self):
10001012
with create_tmp_file() as tmp_file:
1001-
with backends.NetCDF4DataStore(tmp_file, mode='w',
1002-
format='NETCDF4_CLASSIC') as store:
1013+
with backends.NetCDF4DataStore.open(
1014+
tmp_file, mode='w', format='NETCDF4_CLASSIC') as store:
10031015
yield store
10041016

10051017
@contextlib.contextmanager

0 commit comments

Comments (0)