Skip to content

Commit 367d3f6

Browse files
authored
ERR: Raise ValueError when BaseIndexer start & end bounds are unequal length (#44497)
1 parent 279b91f commit 367d3f6

File tree

4 files changed

+71
-17
lines changed

4 files changed

+71
-17
lines changed

doc/source/whatsnew/v1.4.0.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -679,8 +679,8 @@ Groupby/resample/rolling
679679
- Fixed bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` not calculating window bounds correctly for the first row when ``center=True`` and index is decreasing (:issue:`43927`)
680680
- Fixed bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` for centered datetimelike windows with uneven nanosecond (:issue:`43997`)
681681
- Bug in :meth:`GroupBy.nth` failing on ``axis=1`` (:issue:`43926`)
682-
- Fixed bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` not respecting right bound on centered datetime-like windows, if the index contain duplicates (:issue:`#3944`)
683-
682+
- Fixed bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` not respecting right bound on centered datetime-like windows, if the index contain duplicates (:issue:`3944`)
683+
- Bug in :meth:`Series.rolling` and :meth:`DataFrame.rolling` when using a :class:`pandas.api.indexers.BaseIndexer` subclass that returned unequal start and end arrays would segfault instead of raising a ``ValueError`` (:issue:`44470`)
684684

685685
Reshaping
686686
^^^^^^^^^

pandas/core/window/ewm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,13 @@ def __init__(
417417
self.alpha,
418418
)
419419

420+
def _check_window_bounds(
421+
self, start: np.ndarray, end: np.ndarray, num_vals: int
422+
) -> None:
423+
# emw algorithms are iterative with each point
424+
# ExponentialMovingWindowIndexer "bounds" are the entire window
425+
pass
426+
420427
def _get_window_indexer(self) -> BaseIndexer:
421428
"""
422429
Return an indexer class that will compute the window start and end bounds

pandas/core/window/rolling.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,20 @@ def _validate(self) -> None:
227227
if self.method not in ["table", "single"]:
228228
raise ValueError("method must be 'table' or 'single")
229229

230+
def _check_window_bounds(
231+
self, start: np.ndarray, end: np.ndarray, num_vals: int
232+
) -> None:
233+
if len(start) != len(end):
234+
raise ValueError(
235+
f"start ({len(start)}) and end ({len(end)}) bounds must be the "
236+
f"same length"
237+
)
238+
elif len(start) != num_vals:
239+
raise ValueError(
240+
f"start and end bounds ({len(start)}) must be the same length "
241+
f"as the object ({num_vals})"
242+
)
243+
230244
def _create_data(self, obj: NDFrameT) -> NDFrameT:
231245
"""
232246
Split data into blocks & return conformed data.
@@ -311,10 +325,7 @@ def __iter__(self):
311325
center=self.center,
312326
closed=self.closed,
313327
)
314-
315-
assert len(start) == len(
316-
end
317-
), "these should be equal in length from get_window_bounds"
328+
self._check_window_bounds(start, end, len(obj))
318329

319330
for s, e in zip(start, end):
320331
result = obj.iloc[slice(s, e)]
@@ -565,9 +576,7 @@ def calc(x):
565576
center=self.center,
566577
closed=self.closed,
567578
)
568-
assert len(start) == len(
569-
end
570-
), "these should be equal in length from get_window_bounds"
579+
self._check_window_bounds(start, end, len(x))
571580

572581
return func(x, start, end, min_periods, *numba_args)
573582

@@ -608,6 +617,7 @@ def _numba_apply(
608617
center=self.center,
609618
closed=self.closed,
610619
)
620+
self._check_window_bounds(start, end, len(values))
611621
aggregator = executor.generate_shared_aggregator(
612622
func, engine_kwargs, numba_cache_key_str
613623
)
@@ -1544,10 +1554,7 @@ def cov_func(x, y):
15441554
center=self.center,
15451555
closed=self.closed,
15461556
)
1547-
1548-
assert len(start) == len(
1549-
end
1550-
), "these should be equal in length from get_window_bounds"
1557+
self._check_window_bounds(start, end, len(x_array))
15511558

15521559
with np.errstate(all="ignore"):
15531560
mean_x_y = window_aggregations.roll_mean(
@@ -1588,10 +1595,7 @@ def corr_func(x, y):
15881595
center=self.center,
15891596
closed=self.closed,
15901597
)
1591-
1592-
assert len(start) == len(
1593-
end
1594-
), "these should be equal in length from get_window_bounds"
1598+
self._check_window_bounds(start, end, len(x_array))
15951599

15961600
with np.errstate(all="ignore"):
15971601
mean_x_y = window_aggregations.roll_mean(

pandas/tests/window/test_base_indexer.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,3 +452,46 @@ def test_rolling_groupby_with_fixed_forward_many(group_keys, window_size):
452452
manual = manual.set_index(["a", "c"])["b"]
453453

454454
tm.assert_series_equal(result, manual)
455+
456+
457+
def test_unequal_start_end_bounds():
458+
class CustomIndexer(BaseIndexer):
459+
def get_window_bounds(self, num_values, min_periods, center, closed):
460+
return np.array([1]), np.array([1, 2])
461+
462+
indexer = CustomIndexer()
463+
roll = Series(1).rolling(indexer)
464+
match = "start"
465+
with pytest.raises(ValueError, match=match):
466+
roll.mean()
467+
468+
with pytest.raises(ValueError, match=match):
469+
next(iter(roll))
470+
471+
with pytest.raises(ValueError, match=match):
472+
roll.corr(pairwise=True)
473+
474+
with pytest.raises(ValueError, match=match):
475+
roll.cov(pairwise=True)
476+
477+
478+
def test_unequal_bounds_to_object():
479+
# GH 44470
480+
class CustomIndexer(BaseIndexer):
481+
def get_window_bounds(self, num_values, min_periods, center, closed):
482+
return np.array([1]), np.array([2])
483+
484+
indexer = CustomIndexer()
485+
roll = Series([1, 1]).rolling(indexer)
486+
match = "start and end"
487+
with pytest.raises(ValueError, match=match):
488+
roll.mean()
489+
490+
with pytest.raises(ValueError, match=match):
491+
next(iter(roll))
492+
493+
with pytest.raises(ValueError, match=match):
494+
roll.corr(pairwise=True)
495+
496+
with pytest.raises(ValueError, match=match):
497+
roll.cov(pairwise=True)

0 commit comments

Comments
 (0)