Skip to content

use .pint.quantify() as a identity operator #175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ What's new
By `Justus Magin <https://github.com/keewis>`_.
- fix "quantifying" dimension coordinates (:issue:`105`, :pull:`174`).
By `Justus Magin <https://github.com/keewis>`_.
- allow using :py:meth:`DataArray.pint.quantify` and :py:meth:`Dataset.pint.quantify`
as identity operators (:issue:`47`, :pull:`175`).
By `Justus Magin <https://github.com/keewis>`_.

0.2.1 (26 Jul 2021)
-------------------
Expand Down
35 changes: 9 additions & 26 deletions pint_xarray/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import itertools

import pint
from pint import Quantity, Unit
from pint import Unit
from xarray import register_dataarray_accessor, register_dataset_accessor
from xarray.core.dtypes import NA

Expand Down Expand Up @@ -71,16 +71,6 @@ def zip_mappings(*mappings, fill_value=None):
return zipped


def merge_mappings(first, *mappings):
result = first.copy()
for mapping in mappings:
result.update(
{key: value for key, value in mapping.items() if value is not None}
)

return result


def units_to_str_or_none(mapping, unit_format):
formatter = str if not unit_format else lambda v: unit_format.format(v)

Expand Down Expand Up @@ -109,8 +99,8 @@ def either_dict_or_kwargs(positional, keywords, method_name):


def get_registry(unit_registry, new_units, existing_units):
units = merge_mappings(existing_units, new_units)
registries = {unit._REGISTRY for unit in units.values() if isinstance(unit, Unit)}
units = itertools.chain(new_units.values(), existing_units.values())
registries = {unit._REGISTRY for unit in units if isinstance(unit, Unit)}

if unit_registry is None:
if not registries:
Expand All @@ -133,7 +123,7 @@ def get_registry(unit_registry, new_units, existing_units):


def _decide_units(units, registry, unit_attribute):
if units is _default and unit_attribute is _default:
if units is _default and unit_attribute in (None, _default):
# or warn and return None?
raise ValueError("no units given")
elif units in no_unit_values or isinstance(units, Unit):
Expand Down Expand Up @@ -321,13 +311,6 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
array([0.4, 0.9])
Dimensions without coordinates: wavelength
"""

if isinstance(self.da.data, Quantity):
raise ValueError(
f"Cannot attach unit {units} to quantity: data "
f"already has units {self.da.data.units}"
)

if units is None or isinstance(units, (str, pint.Unit)):
if self.da.name in unit_kwargs:
raise ValueError(
Expand All @@ -347,11 +330,11 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
new_units = {}
invalid_units = {}
for name, (unit, attr) in possible_new_units.items():
if unit is not _default or attr is not _default:
if unit not in (_default, None) or attr not in (_default, None):
try:
new_units[name] = _decide_units(unit, registry, attr)
except (ValueError, pint.UndefinedUnitError) as e:
if unit is not _default:
if unit not in (_default, None):
type = "parameter"
reported_unit = unit
else:
Expand All @@ -373,7 +356,7 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
for name, (old, new) in zip_mappings(
existing_units, new_units, fill_value=_default
).items()
if old is not _default and new is not _default
if old is not _default and new is not _default and old != new
}
if overwritten_units:
errors = {
Expand Down Expand Up @@ -1062,7 +1045,7 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
new_units = {}
invalid_units = {}
for name, (unit, attr) in possible_new_units.items():
if unit is not _default or attr is not _default:
if unit is not _default or attr not in (None, _default):
try:
new_units[name] = _decide_units(unit, registry, attr)
except (ValueError, pint.UndefinedUnitError) as e:
Expand All @@ -1088,7 +1071,7 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
for name, (old, new) in zip_mappings(
existing_units, new_units, fill_value=_default
).items()
if old is not _default and new is not _default
if old is not _default and new is not _default and old != new
}
if overwritten_units:
errors = {
Expand Down
3 changes: 3 additions & 0 deletions pint_xarray/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def array_attach_units(data, unit):
raise ValueError(f"cannot use {unit!r} as a unit")

if isinstance(data, pint.Quantity):
if data.units == unit:
return data

raise ValueError(
f"Cannot attach unit {unit!r} to quantity: data "
f"already has units {data.units}"
Expand Down
106 changes: 84 additions & 22 deletions pint_xarray/tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
assert_equal,
assert_identical,
assert_units_equal,
raises_regex,
requires_bottleneck,
requires_dask_array,
requires_scipy,
Expand Down Expand Up @@ -111,10 +110,62 @@ def test_override_units(self, example_unitless_da, no_unit_value):
with pytest.raises(AttributeError):
result["u"].data.units

def test_error_when_already_units(self, example_quantity_da):
def test_error_when_changing_units(self, example_quantity_da):
da = example_quantity_da
with raises_regex(ValueError, "already has units"):
da.pint.quantify()
with pytest.raises(ValueError, match="already has units"):
da.pint.quantify("s")

def test_attach_no_units(self):
arr = xr.DataArray([1, 2, 3], dims="x")
quantified = arr.pint.quantify()
assert_identical(quantified, arr)
assert_units_equal(quantified, arr)

def test_attach_no_new_units(self):
da = xr.DataArray(unit_registry.Quantity([1, 2, 3], "m"), dims="x")
quantified = da.pint.quantify()
assert_identical(quantified, da)
assert_units_equal(quantified, da)

def test_attach_same_units(self):
da = xr.DataArray(unit_registry.Quantity([1, 2, 3], "m"), dims="x")
quantified = da.pint.quantify("m")
assert_identical(quantified, da)
assert_units_equal(quantified, da)

def test_error_when_changing_units_dimension_coordinates(self):
arr = xr.DataArray(
[1, 2, 3],
dims="x",
coords={"x": ("x", [-1, 0, 1], {"units": unit_registry.Unit("m")})},
)
with pytest.raises(ValueError, match="already has units"):
arr.pint.quantify({"x": "s"})

def test_dimension_coordinate_array(self):
ds = xr.Dataset(coords={"x": ("x", [10], {"units": "m"})})
arr = ds.x

# does not actually quantify because `arr` wraps a IndexVariable
# but we still get a `Unit` in the attrs
q = arr.pint.quantify()
assert isinstance(q.attrs["units"], Unit)

def test_dimension_coordinate_array_already_quantified(self):
ds = xr.Dataset(coords={"x": ("x", [10], {"units": unit_registry.Unit("m")})})
arr = ds.x

with pytest.raises(ValueError):
arr.pint.quantify({"x": "s"})

def test_dimension_coordinate_array_already_quantified_same_units(self):
ds = xr.Dataset(coords={"x": ("x", [10], {"units": unit_registry.Unit("m")})})
arr = ds.x

quantified = arr.pint.quantify({"x": "m"})

assert_identical(quantified, arr)
assert_units_equal(quantified, arr)

def test_error_on_nonsense_units(self, example_unitless_da):
da = example_unitless_da
Expand All @@ -135,22 +186,6 @@ def test_parse_integer_inverse(self):
result = da.pint.quantify()
assert result.pint.units == Unit("1 / meter")

def test_dimension_coordinate(self):
ds = xr.Dataset(coords={"x": ("x", [10], {"units": "m"})})
arr = ds.x

# does not actually quantify because `arr` wraps a IndexVariable
# but we still get a `Unit` in the attrs
q = arr.pint.quantify()
assert isinstance(q.attrs["units"], Unit)

def test_dimension_coordinate_already_quantified(self):
ds = xr.Dataset(coords={"x": ("x", [10], {"units": unit_registry.Unit("m")})})
arr = ds.x

with pytest.raises(ValueError):
arr.pint.quantify({"x": "s"})


@pytest.mark.parametrize("formatter", ("", "P", "C"))
@pytest.mark.parametrize("modifier", ("", "~"))
Expand Down Expand Up @@ -308,8 +343,35 @@ def test_override_units(self, example_unitless_ds, no_unit_value):
)

def test_error_when_already_units(self, example_quantity_ds):
with raises_regex(ValueError, "already has units"):
example_quantity_ds.pint.quantify({"funds": "pounds"})
with pytest.raises(ValueError, match="already has units"):
example_quantity_ds.pint.quantify({"funds": "kg"})

def test_attach_no_units(self):
ds = xr.Dataset({"a": ("x", [1, 2, 3])})
quantified = ds.pint.quantify()
assert_identical(quantified, ds)
assert_units_equal(quantified, ds)

def test_attach_no_new_units(self):
ds = xr.Dataset({"a": ("x", unit_registry.Quantity([1, 2, 3], "m"))})
quantified = ds.pint.quantify()

assert_identical(quantified, ds)
assert_units_equal(quantified, ds)

def test_attach_same_units(self):
ds = xr.Dataset({"a": ("x", unit_registry.Quantity([1, 2, 3], "m"))})
quantified = ds.pint.quantify({"a": "m"})

assert_identical(quantified, ds)
assert_units_equal(quantified, ds)

def test_error_when_changing_units_dimension_coordinates(self):
ds = xr.Dataset(
coords={"x": ("x", [-1, 0, 1], {"units": unit_registry.Unit("m")})},
)
with pytest.raises(ValueError, match="already has units"):
ds.pint.quantify({"x": "s"})

def test_error_on_nonsense_units(self, example_unitless_ds):
ds = example_unitless_ds
Expand Down
14 changes: 14 additions & 0 deletions pint_xarray/tests/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@ class TestArrayFunctions:
"already has units",
id="unit object on quantity",
),
pytest.param(
Unit("m"),
Quantity(np.array([0, 1]), "m"),
Quantity(np.array([0, 1]), "m"),
None,
id="unit object on quantity with same unit",
),
pytest.param(
Unit("mm"),
Quantity(np.array([0, 1]), "m"),
None,
"already has units",
id="unit object on quantity with similar unit",
),
),
)
def test_array_attach_units(self, data, unit, expected, match):
Expand Down