Skip to content

Commit 28be10b

Browse files
mroeschkeMatt Roeschke
and
Matt Roeschke
authored
CLN/TYPE: EWM (#34770)
* Move min_periods validation to init * Type signatures * Undo unnecessary casting * consolidate some cython type declariations * tighten up typing and black Co-authored-by: Matt Roeschke <[email protected]>
1 parent b7aff71 commit 28be10b

File tree

2 files changed

+45
-45
lines changed

2 files changed

+45
-45
lines changed

pandas/_libs/window/aggregations.pyx

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,7 +1759,7 @@ def roll_weighted_var(float64_t[:] values, float64_t[:] weights,
17591759
# Exponentially weighted moving average
17601760

17611761

1762-
def ewma(float64_t[:] vals, float64_t com, int adjust, bint ignore_na, int minp):
1762+
def ewma(float64_t[:] vals, float64_t com, bint adjust, bint ignore_na, int minp):
17631763
"""
17641764
Compute exponentially-weighted moving average using center-of-mass.
17651765
@@ -1777,17 +1777,14 @@ def ewma(float64_t[:] vals, float64_t com, int adjust, bint ignore_na, int minp)
17771777
"""
17781778

17791779
cdef:
1780-
Py_ssize_t N = len(vals)
1780+
Py_ssize_t i, nobs, N = len(vals)
17811781
ndarray[float64_t] output = np.empty(N, dtype=float)
17821782
float64_t alpha, old_wt_factor, new_wt, weighted_avg, old_wt, cur
1783-
Py_ssize_t i, nobs
17841783
bint is_observation
17851784

17861785
if N == 0:
17871786
return output
17881787

1789-
minp = max(minp, 1)
1790-
17911788
alpha = 1. / (1. + com)
17921789
old_wt_factor = 1. - alpha
17931790
new_wt = 1. if adjust else alpha
@@ -1831,7 +1828,7 @@ def ewma(float64_t[:] vals, float64_t com, int adjust, bint ignore_na, int minp)
18311828

18321829

18331830
def ewmcov(float64_t[:] input_x, float64_t[:] input_y,
1834-
float64_t com, int adjust, bint ignore_na, int minp, int bias):
1831+
float64_t com, bint adjust, bint ignore_na, int minp, bint bias):
18351832
"""
18361833
Compute exponentially-weighted moving variance using center-of-mass.
18371834
@@ -1851,11 +1848,10 @@ def ewmcov(float64_t[:] input_x, float64_t[:] input_y,
18511848
"""
18521849

18531850
cdef:
1854-
Py_ssize_t N = len(input_x), M = len(input_y)
1851+
Py_ssize_t i, nobs, N = len(input_x), M = len(input_y)
18551852
float64_t alpha, old_wt_factor, new_wt, mean_x, mean_y, cov
18561853
float64_t sum_wt, sum_wt2, old_wt, cur_x, cur_y, old_mean_x, old_mean_y
18571854
float64_t numerator, denominator
1858-
Py_ssize_t i, nobs
18591855
ndarray[float64_t] output
18601856
bint is_observation
18611857

@@ -1866,8 +1862,6 @@ def ewmcov(float64_t[:] input_x, float64_t[:] input_y,
18661862
if N == 0:
18671863
return output
18681864

1869-
minp = max(minp, 1)
1870-
18711865
alpha = 1. / (1. + com)
18721866
old_wt_factor = 1. - alpha
18731867
new_wt = 1. if adjust else alpha

pandas/core/window/ewm.py

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from functools import partial
22
from textwrap import dedent
3+
from typing import Optional, Union
34

45
import numpy as np
56

67
import pandas._libs.window.aggregations as window_aggregations
8+
from pandas._typing import FrameOrSeries
79
from pandas.compat.numpy import function as nv
810
from pandas.util._decorators import Appender, Substitution
911

@@ -24,7 +26,12 @@
2426
"""
2527

2628

27-
def get_center_of_mass(comass, span, halflife, alpha) -> float:
29+
def get_center_of_mass(
30+
comass: Optional[float],
31+
span: Optional[float],
32+
halflife: Optional[float],
33+
alpha: Optional[float],
34+
) -> float:
2835
valid_count = com.count_not_none(comass, span, halflife, alpha)
2936
if valid_count > 1:
3037
raise ValueError("comass, span, halflife, and alpha are mutually exclusive")
@@ -114,7 +121,7 @@ class EWM(_Rolling):
114121
used in calculating the final weighted average of
115122
[:math:`x_0`, None, :math:`x_2`] are :math:`1-\alpha` and :math:`1` if
116123
``adjust=True``, and :math:`1-\alpha` and :math:`\alpha` if ``adjust=False``.
117-
axis : {0 or 'index', 1 or 'columns'}, default 0
124+
axis : {0, 1}, default 0
118125
The axis to use. The value 0 identifies the rows, and 1
119126
identifies the columns.
120127
@@ -159,18 +166,18 @@ class EWM(_Rolling):
159166
def __init__(
160167
self,
161168
obj,
162-
com=None,
163-
span=None,
164-
halflife=None,
165-
alpha=None,
166-
min_periods=0,
167-
adjust=True,
168-
ignore_na=False,
169-
axis=0,
169+
com: Optional[float] = None,
170+
span: Optional[float] = None,
171+
halflife: Optional[float] = None,
172+
alpha: Optional[float] = None,
173+
min_periods: int = 0,
174+
adjust: bool = True,
175+
ignore_na: bool = False,
176+
axis: int = 0,
170177
):
171178
self.obj = obj
172179
self.com = get_center_of_mass(com, span, halflife, alpha)
173-
self.min_periods = min_periods
180+
self.min_periods = max(int(min_periods), 1)
174181
self.adjust = adjust
175182
self.ignore_na = ignore_na
176183
self.axis = axis
@@ -274,16 +281,16 @@ def mean(self, *args, **kwargs):
274281
window_func = partial(
275282
window_func,
276283
com=self.com,
277-
adjust=int(self.adjust),
284+
adjust=self.adjust,
278285
ignore_na=self.ignore_na,
279-
minp=int(self.min_periods),
286+
minp=self.min_periods,
280287
)
281288
return self._apply(window_func)
282289

283290
@Substitution(name="ewm", func_name="std")
284291
@Appender(_doc_template)
285292
@Appender(_bias_template)
286-
def std(self, bias=False, *args, **kwargs):
293+
def std(self, bias: bool = False, *args, **kwargs):
287294
"""
288295
Exponential weighted moving stddev.
289296
"""
@@ -295,28 +302,28 @@ def std(self, bias=False, *args, **kwargs):
295302
@Substitution(name="ewm", func_name="var")
296303
@Appender(_doc_template)
297304
@Appender(_bias_template)
298-
def var(self, bias=False, *args, **kwargs):
305+
def var(self, bias: bool = False, *args, **kwargs):
299306
"""
300307
Exponential weighted moving variance.
301308
"""
302309
nv.validate_window_func("var", args, kwargs)
303310

304311
def f(arg):
305312
return window_aggregations.ewmcov(
306-
arg,
307-
arg,
308-
self.com,
309-
int(self.adjust),
310-
int(self.ignore_na),
311-
int(self.min_periods),
312-
int(bias),
313+
arg, arg, self.com, self.adjust, self.ignore_na, self.min_periods, bias,
313314
)
314315

315316
return self._apply(f)
316317

317318
@Substitution(name="ewm", func_name="cov")
318319
@Appender(_doc_template)
319-
def cov(self, other=None, pairwise=None, bias=False, **kwargs):
320+
def cov(
321+
self,
322+
other: Optional[Union[np.ndarray, FrameOrSeries]] = None,
323+
pairwise: Optional[bool] = None,
324+
bias: bool = False,
325+
**kwargs,
326+
):
320327
"""
321328
Exponential weighted sample covariance.
322329
@@ -350,10 +357,10 @@ def _get_cov(X, Y):
350357
X._prep_values(),
351358
Y._prep_values(),
352359
self.com,
353-
int(self.adjust),
354-
int(self.ignore_na),
355-
int(self.min_periods),
356-
int(bias),
360+
self.adjust,
361+
self.ignore_na,
362+
self.min_periods,
363+
bias,
357364
)
358365
return X._wrap_result(cov)
359366

@@ -363,7 +370,12 @@ def _get_cov(X, Y):
363370

364371
@Substitution(name="ewm", func_name="corr")
365372
@Appender(_doc_template)
366-
def corr(self, other=None, pairwise=None, **kwargs):
373+
def corr(
374+
self,
375+
other: Optional[Union[np.ndarray, FrameOrSeries]] = None,
376+
pairwise: Optional[bool] = None,
377+
**kwargs,
378+
):
367379
"""
368380
Exponential weighted sample correlation.
369381
@@ -394,13 +406,7 @@ def _get_corr(X, Y):
394406

395407
def _cov(x, y):
396408
return window_aggregations.ewmcov(
397-
x,
398-
y,
399-
self.com,
400-
int(self.adjust),
401-
int(self.ignore_na),
402-
int(self.min_periods),
403-
1,
409+
x, y, self.com, self.adjust, self.ignore_na, self.min_periods, 1,
404410
)
405411

406412
x_values = X._prep_values()

0 commit comments

Comments
 (0)