### What is your issue?
In the datatree call today we narrowed down an issue with how datatree maps methods over many variables in many nodes. This issue is essentially xarray-contrib/datatree#67, but I'll attempt to discuss the problem and solution in more general terms.
## Context in xarray
`xarray.Dataset` is essentially a mapping of variable names to `Variable` objects, and most `Dataset` methods implicitly map a method defined on `Variable` over all these variables (e.g. `.mean()`). Sometimes the mapped method can be naively applied to every variable in the dataset, but sometimes it doesn't make sense to apply it to some of the variables. For example `.mean(dim='time')` only makes sense for the variables in the dataset that actually have a `time` dimension.
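To make that concrete, here is a minimal illustration with made-up data: `b` has no `time` dimension, so `.mean(dim='time')` reduces `a` and leaves `b` alone.

```python
import xarray as xr

# Hypothetical example data: only `a` has a "time" dimension.
ds = xr.Dataset({"a": ("time", [1.0, 2.0, 3.0]), "b": 10.0})

# Reduces `a` over "time"; `b` is passed through unaltered.
ds.mean(dim="time")
```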
`xarray.Dataset` handles this for the user by either working out which version of the method does make sense for that variable (e.g. only taking the mean along the reduction dimensions actually present on that variable), or just passing the variable through unaltered. There are some weird subtleties lurking here, e.g. with statistical reductions like `std` and `var` (see line 6853 of the xarray source at commit 239309f).
There is therefore a difference between `ds.map(Variable.{REDUCTION}, dim='time')` and `ds.{REDUCTION}(dim='time')`.
For example:
```python
In [13]: ds = xr.Dataset({'a': ('x', [1, 2]), 'b': 0})

In [14]: ds.isel(x=0)
Out[14]:
<xarray.Dataset> Size: 16B
Dimensions:  ()
Data variables:
    a        int64 8B 1
    b        int64 8B 0

In [15]: ds.map(Variable.isel, x=0)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[15], line 1
----> 1 ds.map(Variable.isel, x=0)
...
ValueError: Dimensions {'x'} do not exist. Expected one or more of ()
```
(Aside: it would be nice for `Dataset.map` to include information about which variable it raised an exception on in the error message.)
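As a rough sketch of what that aside could look like (hypothetical, not current xarray behaviour; `map_with_context` is a made-up helper, and the simplified rebuild at the end drops coords and attrs):

```python
import xarray as xr

def map_with_context(ds: xr.Dataset, func, **kwargs) -> xr.Dataset:
    """Like Dataset.map, but errors name the offending variable."""
    results = {}
    for name, var in ds.data_vars.items():
        try:
            results[name] = func(var, **kwargs)
        except Exception as e:
            # Python 3.11+: attach context to the original exception
            e.add_note(f"Raised while mapping over variable {name!r}")
            raise
    return xr.Dataset(results)  # simplified: drops coords and attrs
```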
Clearly `Dataset.isel` does more than just applying `Variable.isel` using `Dataset.map`.
## Issue in DataTree
In datatree we have to map methods not only over different variables in the same node, but also over different variables in different nodes. Currently the implementation of a method naively maps the `Dataset` method over every node using `map_over_subtree`, but if there is a node containing a variable for which the method's arguments are invalid, it will raise an exception.
This causes problems for users, for example in xarray-contrib/datatree#67. A minimal example of this problem would be:
```python
In [18]: ds1 = xr.Dataset({'a': ('x', [1, 2])})

In [19]: ds2 = xr.Dataset({'b': 0})

In [20]: dt = DataTree.from_dict({'node1': ds1, 'node2': ds2})

In [21]: dt
Out[21]:
DataTree('None', parent=None)
├── DataTree('node1')
│       Dimensions:  (x: 2)
│       Dimensions without coordinates: x
│       Data variables:
│           a        (x) int64 16B 1 2
└── DataTree('node2')
        Dimensions:  ()
        Data variables:
            b        int64 8B 0

In [22]: dt.isel(x=0)
...
ValueError: Dimensions {'x'} do not exist. Expected one or more of FrozenMappingWarningOnValuesAccess({})
Raised whilst mapping function over node with path /node2
```
(The slightly weird error message here is related to the deprecation cycle in #8500.)
We would have preferred that variable `b` in `node2` survived unchanged, like it does in the pure `Dataset` example.
## Desired behaviour
We can kind of think of the desired behaviour like a Hypothesis property we want (xref #1846), but not quite. It would be something like

```
dt.{REDUCTION}().flatten_into_dataset() == dt.flatten_into_dataset().{REDUCTION}()
```

except that `.flatten_into_dataset()` can't really exist for all cases, otherwise we wouldn't need datatree.
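For the toy tree above, a version of that property could be spelled out as follows. This is a sketch: `flatten_into_dataset` doesn't exist, so it is faked here with `xr.merge`, which only works when names and dims don't conflict across nodes, and the left-hand side assumes the fix proposed below (today `dt.mean(dim="x")` raises).

```python
import xarray as xr

def flatten_into_dataset(dt) -> xr.Dataset:
    # Stand-in for the hypothetical method: merge every node's dataset.
    # Only possible when variables/dims don't conflict across nodes.
    return xr.merge([node.ds for node in dt.subtree])

# The desired invariant for reductions over "x":
assert flatten_into_dataset(dt.mean(dim="x")).identical(
    flatten_into_dataset(dt).mean(dim="x")
)
```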
## Proposed Solution
There are two ways I can imagine implementing this.
1. Use `map_over_subtree` to apply the method as-is and try to catch known possible `KeyError`s for missing dimensions. This would be fragile (see the sketch below).
2. Do some kind of pre-checking of the data in the tree, potentially adjusting the method before applying it using `map_over_subtree`.
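For illustration, option (1) might look something like this sketch (`isel_lenient` is a made-up name; it is fragile because it has to guess that any `ValueError` means a missing dimension):

```python
from datatree import map_over_subtree

def isel_lenient(dt, **indexers):
    @map_over_subtree
    def _isel(ds):
        try:
            return ds.isel(**indexers)
        except ValueError:
            # Assume the failure means the dims aren't present on this node,
            # but this also swallows unrelated ValueErrors.
            return ds
    return _isel(dt)
```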
I think @shoyer and I concluded that we should implement (2), in the form of some kind of new primitive, i.e. `DataTree.reduce`. (Actually `DataTree.reduce` already exists, but it should be changed so that it doesn't just `map_over_subtree` `Dataset.reduce`.) Taking after `Dataset.reduce`, it would look something like this:
```python
from typing import Callable

class DataTree:
    def reduce(self, func: Callable, dim: Dims = None, **kwargs) -> DataTree:
        dims = set(dim) if dim is not None else set()
        # every dimension that appears anywhere in the tree
        all_dims_in_tree = {d for node in self.subtree for d in node.dims}
        missing_dims = tuple(d for d in dims if d not in all_dims_in_tree)
        if missing_dims:
            raise ValueError(
                f"Dimensions {missing_dims} do not exist on any node in the tree"
            )
        # TODO this could probably be refactored to call `map_over_subtree`
        for node in self.subtree:
            # using only the reduction dims that are actually present on this
            # node would fix datatree GH issue #67
            reduce_dims = [d for d in node.dims if d in dims]
            result = node.ds.reduce(func, dim=reduce_dims, **kwargs)
            # TODO build the result and return it
```
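Under this sketch, the failing example from above would behave as desired (expected behaviour under the proposal, not something that works today):

```python
import numpy as np

# "x" exists somewhere in the tree, so no error is raised;
# node1 reduces `a` over "x", while node2 has reduce_dims == []
# and `b` survives unchanged instead of raising.
result = dt.reduce(np.mean, dim=["x"])
```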
Then every method that has this pattern of acting over one or more dims should be mapped over the tree using `DataTree.reduce`, not `map_over_subtree`.
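Each such method could then delegate to the new primitive, mirroring how `Dataset` reductions funnel through `Dataset.reduce` (a sketch; `np.mean` stands in for xarray's duck-array mean):

```python
import numpy as np

class DataTree:
    # ... reduce() as sketched above ...

    def mean(self, dim=None, **kwargs):
        # Route through DataTree.reduce so the per-node dimension
        # handling lives in one place, not in every method.
        return self.reduce(np.mean, dim=dim, **kwargs)
```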