Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 68f6bf4

Browse files
Reimplement Series.cov using prange (#451)
1 parent 9f63b4a commit 68f6bf4

File tree

3 files changed

+39
-25
lines changed

3 files changed

+39
-25
lines changed

sdc/datatypes/hpat_pandas_series_functions.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from numba.extending import intrinsic
4343
from numba import (types, numpy_support, cgutils)
4444
from numba.typed import Dict
45+
from numba import prange
4546

4647
import sdc
4748
import sdc.datatypes.common_functions as common_functions
@@ -4638,32 +4639,31 @@ def hpat_pandas_series_cov(self, other, min_periods=None):
46384639

46394640
def hpat_pandas_series_cov_impl(self, other, min_periods=None):
46404641

4641-
if min_periods is None:
4642-
min_periods = 1
4643-
4644-
if len(self._data) == 0 or len(other._data) == 0:
4645-
return numpy.nan
4646-
4647-
self_arr = self._data[:min(len(self._data), len(other._data))]
4648-
other_arr = other._data[:min(len(self._data), len(other._data))]
4642+
if min_periods is None or min_periods < 2:
4643+
min_periods = 2
46494644

4650-
invalid = numpy.isnan(self_arr) | numpy.isnan(other_arr)
4651-
if invalid.any():
4652-
self_arr = self_arr[~invalid]
4653-
other_arr = other_arr[~invalid]
4645+
min_len = min(len(self._data), len(other._data))
46544646

4655-
if len(self_arr) < min_periods:
4647+
if min_len == 0:
46564648
return numpy.nan
46574649

4658-
new_self = pandas.Series(self_arr)
4659-
4660-
ma = new_self.mean()
4661-
mb = other.mean()
4662-
4663-
if numpy.isinf(mb):
4650+
other_sum = 0.
4651+
self_sum = 0.
4652+
self_other_sum = 0.
4653+
total_count = 0
4654+
for i in prange(min_len):
4655+
s = self._data[i]
4656+
o = other._data[i]
4657+
if not (numpy.isnan(s) or numpy.isnan(o)):
4658+
self_sum += s
4659+
other_sum += o
4660+
self_other_sum += s*o
4661+
total_count += 1
4662+
4663+
if total_count < min_periods:
46644664
return numpy.nan
46654665

4666-
return ((self_arr - ma) * (other_arr - mb)).sum() / (new_self.count() - 1.0)
4666+
return (self_other_sum - self_sum*other_sum/total_count)/(total_count - 1)
46674667

46684668
return hpat_pandas_series_cov_impl
46694669

sdc/tests/test_series.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5031,9 +5031,10 @@ def test_series_cov_impl(S1, S2, min_periods=None):
50315031
S1 = pd.Series(input_data1)
50325032
S2 = pd.Series(input_data2)
50335033
for period in [None, 2, 1, 8, -4]:
5034-
result_ref = test_series_cov_impl(S1, S2, min_periods=period)
5035-
result = hpat_func(S1, S2, min_periods=period)
5036-
np.testing.assert_allclose(result, result_ref)
5034+
with self.subTest(input_data1=input_data1, input_data2=input_data2, min_periods=period):
5035+
result_ref = test_series_cov_impl(S1, S2, min_periods=period)
5036+
result = hpat_func(S1, S2, min_periods=period)
5037+
np.testing.assert_allclose(result, result_ref)
50375038

50385039
@skip_sdc_jit('Series.cov() parameter "min_periods" unsupported')
50395040
def test_series_cov_unsupported_dtype(self):

sdc/tests/tests_perf/test_perf_series.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ def usecase_series_dropna(input_data):
220220
finish_time = time.time()
221221

222222
return finish_time - start_time, res
223-
224-
223+
224+
225225
def usecase_series_chain_add_and_sum(A, B):
226226
start_time = time.time()
227227
res = (A + B).sum()
@@ -259,6 +259,15 @@ def usecase_series_isna(input_data):
259259
return res_time, res
260260

261261

262+
def usecase_series_cov(A, B):
263+
start_time = time.time()
264+
res = A.cov(B)
265+
finish_time = time.time()
266+
res_time = finish_time - start_time
267+
268+
return res_time, res
269+
270+
262271
# python -m sdc.runtests sdc.tests.tests_perf.test_perf_series.TestSeriesMethods
263272
class TestSeriesMethods(TestBase):
264273
@classmethod
@@ -292,6 +301,7 @@ def setUpClass(cls):
292301
'series_astype_int': [2 * 10 ** 7],
293302
'series_fillna': [2 * 10 ** 7],
294303
'series_isna': [2 * 10 ** 7],
304+
'series_cov': [10 ** 8]
295305
}
296306

297307
def _test_jitted(self, pyfunc, record, *args, **kwargs):
@@ -434,3 +444,6 @@ def test_series_float_fillna(self):
434444

435445
def test_series_float_isna(self):
436446
self._test_case(usecase_series_fillna, 'series_isna')
447+
448+
def test_series_float_cov(self):
449+
self._test_series_binary_operations(usecase_series_cov, 'series_cov')

0 commit comments

Comments
 (0)