1
1
from functools import partial
2
2
from textwrap import dedent
3
+ from typing import Optional , Union
3
4
4
5
import numpy as np
5
6
6
7
import pandas ._libs .window .aggregations as window_aggregations
8
+ from pandas ._typing import FrameOrSeries
7
9
from pandas .compat .numpy import function as nv
8
10
from pandas .util ._decorators import Appender , Substitution
9
11
24
26
"""
25
27
26
28
27
- def get_center_of_mass (comass , span , halflife , alpha ) -> float :
29
+ def get_center_of_mass (
30
+ comass : Optional [float ],
31
+ span : Optional [float ],
32
+ halflife : Optional [float ],
33
+ alpha : Optional [float ],
34
+ ) -> float :
28
35
valid_count = com .count_not_none (comass , span , halflife , alpha )
29
36
if valid_count > 1 :
30
37
raise ValueError ("comass, span, halflife, and alpha are mutually exclusive" )
@@ -114,7 +121,7 @@ class EWM(_Rolling):
114
121
used in calculating the final weighted average of
115
122
[:math:`x_0`, None, :math:`x_2`] are :math:`1-\alpha` and :math:`1` if
116
123
``adjust=True``, and :math:`1-\alpha` and :math:`\alpha` if ``adjust=False``.
117
- axis : {0 or 'index' , 1 or 'columns' }, default 0
124
+ axis : {0, 1}, default 0
118
125
The axis to use. The value 0 identifies the rows, and 1
119
126
identifies the columns.
120
127
@@ -159,18 +166,18 @@ class EWM(_Rolling):
159
166
def __init__ (
160
167
self ,
161
168
obj ,
162
- com = None ,
163
- span = None ,
164
- halflife = None ,
165
- alpha = None ,
166
- min_periods = 0 ,
167
- adjust = True ,
168
- ignore_na = False ,
169
- axis = 0 ,
169
+ com : Optional [ float ] = None ,
170
+ span : Optional [ float ] = None ,
171
+ halflife : Optional [ float ] = None ,
172
+ alpha : Optional [ float ] = None ,
173
+ min_periods : int = 0 ,
174
+ adjust : bool = True ,
175
+ ignore_na : bool = False ,
176
+ axis : int = 0 ,
170
177
):
171
178
self .obj = obj
172
179
self .com = get_center_of_mass (com , span , halflife , alpha )
173
- self .min_periods = min_periods
180
+ self .min_periods = max ( int ( min_periods ), 1 )
174
181
self .adjust = adjust
175
182
self .ignore_na = ignore_na
176
183
self .axis = axis
@@ -274,16 +281,16 @@ def mean(self, *args, **kwargs):
274
281
window_func = partial (
275
282
window_func ,
276
283
com = self .com ,
277
- adjust = int ( self .adjust ) ,
284
+ adjust = self .adjust ,
278
285
ignore_na = self .ignore_na ,
279
- minp = int ( self .min_periods ) ,
286
+ minp = self .min_periods ,
280
287
)
281
288
return self ._apply (window_func )
282
289
283
290
@Substitution (name = "ewm" , func_name = "std" )
284
291
@Appender (_doc_template )
285
292
@Appender (_bias_template )
286
- def std (self , bias = False , * args , ** kwargs ):
293
+ def std (self , bias : bool = False , * args , ** kwargs ):
287
294
"""
288
295
Exponential weighted moving stddev.
289
296
"""
@@ -295,28 +302,28 @@ def std(self, bias=False, *args, **kwargs):
295
302
@Substitution (name = "ewm" , func_name = "var" )
296
303
@Appender (_doc_template )
297
304
@Appender (_bias_template )
298
- def var (self , bias = False , * args , ** kwargs ):
305
+ def var (self , bias : bool = False , * args , ** kwargs ):
299
306
"""
300
307
Exponential weighted moving variance.
301
308
"""
302
309
nv .validate_window_func ("var" , args , kwargs )
303
310
304
311
def f (arg ):
305
312
return window_aggregations .ewmcov (
306
- arg ,
307
- arg ,
308
- self .com ,
309
- int (self .adjust ),
310
- int (self .ignore_na ),
311
- int (self .min_periods ),
312
- int (bias ),
313
+ arg , arg , self .com , self .adjust , self .ignore_na , self .min_periods , bias ,
313
314
)
314
315
315
316
return self ._apply (f )
316
317
317
318
@Substitution (name = "ewm" , func_name = "cov" )
318
319
@Appender (_doc_template )
319
- def cov (self , other = None , pairwise = None , bias = False , ** kwargs ):
320
+ def cov (
321
+ self ,
322
+ other : Optional [Union [np .ndarray , FrameOrSeries ]] = None ,
323
+ pairwise : Optional [bool ] = None ,
324
+ bias : bool = False ,
325
+ ** kwargs ,
326
+ ):
320
327
"""
321
328
Exponential weighted sample covariance.
322
329
@@ -350,10 +357,10 @@ def _get_cov(X, Y):
350
357
X ._prep_values (),
351
358
Y ._prep_values (),
352
359
self .com ,
353
- int ( self .adjust ) ,
354
- int ( self .ignore_na ) ,
355
- int ( self .min_periods ) ,
356
- int ( bias ) ,
360
+ self .adjust ,
361
+ self .ignore_na ,
362
+ self .min_periods ,
363
+ bias ,
357
364
)
358
365
return X ._wrap_result (cov )
359
366
@@ -363,7 +370,12 @@ def _get_cov(X, Y):
363
370
364
371
@Substitution (name = "ewm" , func_name = "corr" )
365
372
@Appender (_doc_template )
366
- def corr (self , other = None , pairwise = None , ** kwargs ):
373
+ def corr (
374
+ self ,
375
+ other : Optional [Union [np .ndarray , FrameOrSeries ]] = None ,
376
+ pairwise : Optional [bool ] = None ,
377
+ ** kwargs ,
378
+ ):
367
379
"""
368
380
Exponential weighted sample correlation.
369
381
@@ -394,13 +406,7 @@ def _get_corr(X, Y):
394
406
395
407
def _cov (x , y ):
396
408
return window_aggregations .ewmcov (
397
- x ,
398
- y ,
399
- self .com ,
400
- int (self .adjust ),
401
- int (self .ignore_na ),
402
- int (self .min_periods ),
403
- 1 ,
409
+ x , y , self .com , self .adjust , self .ignore_na , self .min_periods , 1 ,
404
410
)
405
411
406
412
x_values = X ._prep_values ()
0 commit comments