Skip to content

Commit b4e3cbc

Browse files
kmuehlbauer, scottcha, pre-commit-ci[bot], Illviljan, and dcherian
authored
Fill missing data_vars during concat by reindexing (#7400)
* Fill missing data variables during concat by reindexing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * FIX: use `Any` for type of `fill_value` as this seems consistent with other places * ENH: add tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * typing Co-authored-by: Illviljan <[email protected]> * typing Co-authored-by: Illviljan <[email protected]> * typing Co-authored-by: Illviljan <[email protected]> * use None instead of False Co-authored-by: Illviljan <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * concatenate variable in any case if variable has concat_dim * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add tests from @scottcha #3545 * typing * fix typing * fix tests with, finalize typing * add whats-new.rst entry * Update xarray/tests/test_concat.py Co-authored-by: Illviljan <[email protected]> * Update xarray/tests/test_concat.py Co-authored-by: Illviljan <[email protected]> * add TODO, fix numpy.random.default_rng * change np.random to use Generator * move code for variable order into dedicated function, merge with _parse_datasets, provide fast lane for variable order estimation * fix comment * Use order from first dataset, append missing variables to the end * ensure fill_value is dict * ensure fill_value in align * simplify combined_var, fix test * revert fill_value for alignment.py * derive variable order in order of appearance as suggested per review * remove unneeded enumerate * Use alignment.reindex_variables instead. 
This also removes the need to handle fill_value * small cleanup * Update doc/whats-new.rst Co-authored-by: Deepak Cherian <[email protected]> * adapt tests as per review request, fix ensure_common_dims * adapt tests as per review request * fix whats-new.rst * add whats-new.rst entry * Add additional test with scalar data_var * remove erroneous content from whats-new.rst Co-authored-by: Scott Chamberlin <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Illviljan <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent b21f62e commit b4e3cbc

File tree

3 files changed

+466
-24
lines changed

3 files changed

+466
-24
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ Deprecations
3535
Bug fixes
3636
~~~~~~~~~
3737

38+
- :py:func:`xarray.concat` can now concatenate variables present in some datasets but
39+
not others (:issue:`508`, :pull:`7400`).
40+
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_ and `Scott Chamberlin <https://github.com/scottcha>`_.
3841

3942
Documentation
4043
~~~~~~~~~~~~~

xarray/core/concat.py

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pandas as pd
66

77
from xarray.core import dtypes, utils
8-
from xarray.core.alignment import align
8+
from xarray.core.alignment import align, reindex_variables
99
from xarray.core.duck_array_ops import lazy_array_equiv
1010
from xarray.core.indexes import Index, PandasIndex
1111
from xarray.core.merge import (
@@ -378,7 +378,9 @@ def process_subset_opt(opt, subset):
378378

379379
elif opt == "all":
380380
concat_over.update(
381-
set(getattr(datasets[0], subset)) - set(datasets[0].dims)
381+
set().union(
382+
*list(set(getattr(d, subset)) - set(d.dims) for d in datasets)
383+
)
382384
)
383385
elif opt == "minimal":
384386
pass
@@ -406,19 +408,26 @@ def process_subset_opt(opt, subset):
406408

407409
# determine dimensional coordinate names and a dict mapping name to DataArray
408410
def _parse_datasets(
409-
datasets: Iterable[T_Dataset],
410-
) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]:
411-
411+
datasets: list[T_Dataset],
412+
) -> tuple[
413+
dict[Hashable, Variable],
414+
dict[Hashable, int],
415+
set[Hashable],
416+
set[Hashable],
417+
list[Hashable],
418+
]:
412419
dims: set[Hashable] = set()
413420
all_coord_names: set[Hashable] = set()
414421
data_vars: set[Hashable] = set() # list of data_vars
415422
dim_coords: dict[Hashable, Variable] = {} # maps dim name to variable
416423
dims_sizes: dict[Hashable, int] = {} # shared dimension sizes to expand variables
424+
variables_order: dict[Hashable, Variable] = {} # variables in order of appearance
417425

418426
for ds in datasets:
419427
dims_sizes.update(ds.dims)
420428
all_coord_names.update(ds.coords)
421429
data_vars.update(ds.data_vars)
430+
variables_order.update(ds.variables)
422431

423432
# preserves ordering of dimensions
424433
for dim in ds.dims:
@@ -429,7 +438,7 @@ def _parse_datasets(
429438
dim_coords[dim] = ds.coords[dim].variable
430439
dims = dims | set(ds.dims)
431440

432-
return dim_coords, dims_sizes, all_coord_names, data_vars
441+
return dim_coords, dims_sizes, all_coord_names, data_vars, list(variables_order)
433442

434443

435444
def _dataset_concat(
@@ -439,7 +448,7 @@ def _dataset_concat(
439448
coords: str | list[str],
440449
compat: CompatOptions,
441450
positions: Iterable[Iterable[int]] | None,
442-
fill_value: object = dtypes.NA,
451+
fill_value: Any = dtypes.NA,
443452
join: JoinOptions = "outer",
444453
combine_attrs: CombineAttrsOptions = "override",
445454
) -> T_Dataset:
@@ -471,7 +480,9 @@ def _dataset_concat(
471480
align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value)
472481
)
473482

474-
dim_coords, dims_sizes, coord_names, data_names = _parse_datasets(datasets)
483+
dim_coords, dims_sizes, coord_names, data_names, vars_order = _parse_datasets(
484+
datasets
485+
)
475486
dim_names = set(dim_coords)
476487
unlabeled_dims = dim_names - coord_names
477488

@@ -525,7 +536,7 @@ def _dataset_concat(
525536

526537
# we've already verified everything is consistent; now, calculate
527538
# shared dimension sizes so we can expand the necessary variables
528-
def ensure_common_dims(vars):
539+
def ensure_common_dims(vars, concat_dim_lengths):
529540
# ensure each variable with the given name shares the same
530541
# dimensions and the same shape for all of them except along the
531542
# concat dimension
@@ -553,16 +564,35 @@ def get_indexes(name):
553564
data = var.set_dims(dim).values
554565
yield PandasIndex(data, dim, coord_dtype=var.dtype)
555566

567+
# create concatenation index, needed for later reindexing
568+
concat_index = list(range(sum(concat_dim_lengths)))
569+
556570
# stack up each variable and/or index to fill-out the dataset (in order)
557571
# n.b. this loop preserves variable order, needed for groupby.
558-
for name in datasets[0].variables:
572+
for name in vars_order:
559573
if name in concat_over and name not in result_indexes:
560-
try:
561-
vars = ensure_common_dims([ds[name].variable for ds in datasets])
562-
except KeyError:
563-
raise ValueError(f"{name!r} is not present in all datasets.")
564-
565-
# Try concatenate the indexes, concatenate the variables when no index
574+
variables = []
575+
variable_index = []
576+
var_concat_dim_length = []
577+
for i, ds in enumerate(datasets):
578+
if name in ds.variables:
579+
variables.append(ds[name].variable)
580+
# add to variable index, needed for reindexing
581+
var_idx = [
582+
sum(concat_dim_lengths[:i]) + k
583+
for k in range(concat_dim_lengths[i])
584+
]
585+
variable_index.extend(var_idx)
586+
var_concat_dim_length.append(len(var_idx))
587+
else:
588+
# raise if coordinate not in all datasets
589+
if name in coord_names:
590+
raise ValueError(
591+
f"coordinate {name!r} not present in all datasets."
592+
)
593+
vars = ensure_common_dims(variables, var_concat_dim_length)
594+
595+
# Try to concatenate the indexes, concatenate the variables when no index
566596
# is found on all datasets.
567597
indexes: list[Index] = list(get_indexes(name))
568598
if indexes:
@@ -589,6 +619,15 @@ def get_indexes(name):
589619
combined_var = concat_vars(
590620
vars, dim, positions, combine_attrs=combine_attrs
591621
)
622+
# reindex if variable is not present in all datasets
623+
if len(variable_index) < len(concat_index):
624+
combined_var = reindex_variables(
625+
variables={name: combined_var},
626+
dim_pos_indexers={
627+
dim: pd.Index(variable_index).get_indexer(concat_index)
628+
},
629+
fill_value=fill_value,
630+
)[name]
592631
result_vars[name] = combined_var
593632

594633
elif name in result_vars:

0 commit comments

Comments (0)