diff --git a/docs/whats-new.rst b/docs/whats-new.rst index a6381bbd..e05cf824 100644 --- a/docs/whats-new.rst +++ b/docs/whats-new.rst @@ -10,6 +10,8 @@ What's new By `Justus Magin `_. - preserve :py:class:`pandas.MultiIndex` objects (:issue:`164`, :pull:`168`). By `Justus Magin `_. +- fix "quantifying" dimension coordinates (:issue:`105`, :pull:`174`). + By `Justus Magin `_. 0.2.1 (26 Jul 2021) ------------------- diff --git a/pint_xarray/accessors.py b/pint_xarray/accessors.py index 4ebe3241..7df54665 100644 --- a/pint_xarray/accessors.py +++ b/pint_xarray/accessors.py @@ -142,7 +142,10 @@ def _decide_units(units, registry, unit_attribute): elif units is _default: if unit_attribute in no_unit_values: return unit_attribute - units = registry.parse_units(unit_attribute) + if isinstance(unit_attribute, Unit): + units = unit_attribute + else: + units = registry.parse_units(unit_attribute) else: units = registry.parse_units(units) return units @@ -360,6 +363,31 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs): if invalid_units: raise ValueError(format_error_message(invalid_units, "parse")) + existing_units = { + name: unit + for name, unit in conversion.extract_units(self.da).items() + if isinstance(unit, Unit) + } + overwritten_units = { + name: (old, new) + for name, (old, new) in zip_mappings( + existing_units, new_units, fill_value=_default + ).items() + if old is not _default and new is not _default + } + if overwritten_units: + errors = { + name: ( + new, + ValueError( + f"Cannot attach unit {repr(new)} to quantity: data " + f"already has units {repr(old)}" + ), + ) + for name, (old, new) in overwritten_units.items() + } + raise ValueError(format_error_message(errors, "attach")) + return self.da.pipe(conversion.strip_unit_attributes).pipe( conversion.attach_units, new_units ) @@ -1050,6 +1078,31 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs): if invalid_units: raise ValueError(format_error_message(invalid_units, "parse")) + existing_units = { + name: unit + for name, unit in conversion.extract_units(self.ds).items() + if isinstance(unit, Unit) + } + overwritten_units = { + name: (old, new) + for name, (old, new) in zip_mappings( + existing_units, new_units, fill_value=_default + ).items() + if old is not _default and new is not _default + } + if overwritten_units: + errors = { + name: ( + new, + ValueError( + f"Cannot attach unit {repr(new)} to quantity: data " + f"already has units {repr(old)}" + ), + ) + for name, (old, new) in overwritten_units.items() + } + raise ValueError(format_error_message(errors, "attach")) + return self.ds.pipe(conversion.strip_unit_attributes).pipe( conversion.attach_units, new_units ) diff --git a/pint_xarray/tests/test_accessors.py b/pint_xarray/tests/test_accessors.py index a457ed09..91e04517 100644 --- a/pint_xarray/tests/test_accessors.py +++ b/pint_xarray/tests/test_accessors.py @@ -135,6 +135,22 @@ def test_parse_integer_inverse(self): result = da.pint.quantify() assert result.pint.units == Unit("1 / meter") + def test_dimension_coordinate(self): + ds = xr.Dataset(coords={"x": ("x", [10], {"units": "m"})}) + arr = ds.x + + # does not actually quantify because `arr` wraps a IndexVariable + # but we still get a `Unit` in the attrs + q = arr.pint.quantify() + assert isinstance(q.attrs["units"], Unit) + + def test_dimension_coordinate_already_quantified(self): + ds = xr.Dataset(coords={"x": ("x", [10], {"units": unit_registry.Unit("m")})}) + arr = ds.x + + with pytest.raises(ValueError): + arr.pint.quantify({"x": "s"}) + @pytest.mark.parametrize("formatter", ("", "P", "C")) @pytest.mark.parametrize("modifier", ("", "~")) @@ -313,6 +329,20 @@ def test_error_indicates_problematic_variable(self, example_unitless_ds): with pytest.raises(ValueError, match="'users'"): ds.pint.quantify(units={"users": "aecjhbav"}) + def test_existing_units(self, example_quantity_ds): + ds = example_quantity_ds.copy() + ds.t.attrs["units"] = unit_registry.Unit("m") + + with pytest.raises(ValueError, match="Cannot attach"): + ds.pint.quantify({"funds": "kg"}) + + def test_existing_units_dimension(self, example_quantity_ds): + ds = example_quantity_ds.copy() + ds.t.attrs["units"] = unit_registry.Unit("m") + + with pytest.raises(ValueError, match="Cannot attach"): + ds.pint.quantify({"t": "s"}) + class TestDequantifyDataSet: def test_strip_units(self, example_quantity_ds):