Skip to content

Commit a49aac9

Browse files
committed
Handle timedelta serde
1 parent a2edd4a commit a49aac9

File tree

3 files changed

+51
-4
lines changed

3 files changed

+51
-4
lines changed

src/zarr/core/metadata/v2.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,13 @@ def _parse_structured_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
310310
raise ValueError(f"Fill_value {fill_value} is not valid for dtype {dtype}.") from e
311311

312312

313+
def _parse_timedelta(td: dict[str, Any]) -> np.timedelta64:
314+
if td["value"] is None:
315+
return np.timedelta64("NaT")
316+
else:
317+
return np.timedelta64(int(td["value"]), td["unit"])
318+
319+
313320
def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
314321
"""
315322
Parse a potential fill value into a value that is compatible with the provided dtype.
@@ -329,13 +336,18 @@ def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
329336
if fill_value is None or dtype.hasobject:
330337
# Pass through None or if dtype is object
331338
pass
332-
elif dtype.kind in "M":
339+
elif dtype.kind == "M": # datetime
333340
# Check for both string "NaT" and the int64 representation of NaT
334341
if fill_value == "NaT" or fill_value == np.iinfo(np.int64).min:
335342
fill_value = dtype.type("NaT")
336343
else:
337344
fill_value = np.array(fill_value, dtype=dtype)[()]
338345
# Fall through for non-NaT datetime/timedelta values (handled below)
346+
elif dtype.kind == "m": # timedelta
347+
if isinstance(fill_value, dict):
348+
return _parse_timedelta(fill_value)
349+
else: # if raw value is passed rather than unit-based serialization
350+
return np.timedelta64(fill_value)
339351
elif dtype.fields is not None:
340352
# the dtype is structured (has multiple fields), so the fill_value might be a
341353
# compound value (e.g., a tuple or dict) that needs field-wise processing.
@@ -373,6 +385,14 @@ def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
373385
return fill_value
374386

375387

388+
def _serialize_timedelta(td: np.timedelta64) -> JSON:
389+
if np.isnat(td):
390+
return {"value": None, "unit": None}
391+
else:
392+
val, unit = int(td.astype(int)), td.dtype.name.split("[")[-1][:-1]
393+
return {"value": val, "unit": unit}
394+
395+
376396
def _serialize_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> JSON:
377397
serialized: JSON
378398

@@ -383,7 +403,9 @@ def _serialize_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> JSON:
383403
# that mypy isn't aware of. The fact that we have S or V dtype here
384404
# means we should have a bytes-type fill_value.
385405
serialized = base64.standard_b64encode(cast(bytes, fill_value)).decode("ascii")
386-
elif isinstance(fill_value, np.datetime64):
406+
elif dtype.kind == "m":
407+
serialized = _serialize_timedelta(fill_value)
408+
elif dtype.kind == "M":
387409
serialized = np.datetime_as_string(fill_value)
388410
elif isinstance(fill_value, numbers.Integral):
389411
serialized = int(fill_value)
@@ -423,7 +445,9 @@ def _default_fill_value(dtype: np.dtype[Any]) -> Any:
423445
return b""
424446
elif dtype.kind in "UO":
425447
return ""
426-
elif dtype.kind in "Mm":
448+
elif dtype.kind in "m": # timedelta64
449+
return np.timedelta64("NaT")
450+
elif dtype.kind in "M": # datetime64
427451
return dtype.type("nat")
428452
elif dtype.kind == "V":
429453
if dtype.fields is not None:

src/zarr/testing/strategies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def v2_dtypes() -> st.SearchStrategy[np.dtype]:
6464
| npst.byte_string_dtypes(endianness="=")
6565
| npst.unicode_string_dtypes(endianness="=")
6666
| npst.datetime64_dtypes(endianness="=")
67-
# | npst.timedelta64_dtypes()
67+
| npst.timedelta64_dtypes(endianness="=")
6868
)
6969

7070

tests/test_metadata/test_v2.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,29 @@ def test_parse_v2_fill_value_non_nat(dtype_str: str) -> None:
340340
assert meta.fill_value.dtype == dtype
341341

342342

343+
@pytest.mark.parametrize("dtype_str", ["timedelta64[ms]"])
344+
@pytest.mark.parametrize("fill_value_raw", [12345, np.timedelta64(12345, "ms")]) # Example values
345+
def test_parse_v2_fill_value_non_nat_raw_timedelta(dtype_str: str, fill_value_raw: Any) -> None:
346+
dtype = np.dtype(dtype_str)
347+
expected_value = (
348+
np.timedelta64(fill_value_raw, "ms") if isinstance(fill_value_raw, int) else fill_value_raw
349+
)
350+
351+
metadata_dict = {
352+
"zarr_format": 2,
353+
"shape": (10,),
354+
"chunks": (5,),
355+
"dtype": dtype.str,
356+
"compressor": None,
357+
"filters": None,
358+
"fill_value": fill_value_raw,
359+
"order": "C",
360+
}
361+
meta = ArrayV2Metadata.from_dict(metadata_dict)
362+
assert isinstance(meta.fill_value, np.timedelta64)
363+
assert meta.fill_value == expected_value
364+
365+
343366
def test_from_dict_extra_fields() -> None:
344367
data = {
345368
"_nczarr_array": {"dimrefs": ["/dim1", "/dim2"], "storage": "chunked"},

0 commit comments

Comments
 (0)