Skip to content

Commit 60c320c

Browse files
bpo-37295: Optimize math.comb() and math.perm() (GH-29090)
For very large numbers use divide-and-conquer algorithm for getting benefit of Karatsuba multiplication of large numbers. Do calculations completely in C unsigned long long instead of Python integers if possible.
1 parent 628abe4 commit 60c320c

File tree

3 files changed

+198
-93
lines changed

3 files changed

+198
-93
lines changed

Doc/whatsnew/3.11.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,11 @@ Optimizations
351351
* Pure ASCII strings are now normalized in constant time by :func:`unicodedata.normalize`.
352352
(Contributed by Dong-hee Na in :issue:`44987`.)
353353

354+
* :mod:`math` functions :func:`~math.comb` and :func:`~math.perm` are now up
355+
to 10 times or more faster for large arguments (the speed up is larger for
356+
larger *k*).
357+
(Contributed by Serhiy Storchaka in :issue:`37295`.)
358+
354359

355360
CPython bytecode changes
356361
========================
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Optimize :func:`math.comb` and :func:`math.perm`.

Modules/mathmodule.c

Lines changed: 192 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3221,6 +3221,138 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
32213221
}
32223222

32233223

3224+
/* Number of permutations and combinations.
3225+
* P(n, k) = n! / (n-k)!
3226+
* C(n, k) = P(n, k) / k!
3227+
*/
3228+
3229+
/* Calculate C(n, k) for n in the 63-bit range. */
3230+
static PyObject *
3231+
perm_comb_small(unsigned long long n, unsigned long long k, int iscomb)
3232+
{
3233+
/* long long is at least 64 bit */
3234+
static const unsigned long long fast_comb_limits[] = {
3235+
0, ULLONG_MAX, 4294967296ULL, 3329022, 102570, 13467, 3612, 1449, // 0-7
3236+
746, 453, 308, 227, 178, 147, 125, 110, // 8-15
3237+
99, 90, 84, 79, 75, 72, 69, 68, // 16-23
3238+
66, 65, 64, 63, 63, 62, 62, 62, // 24-31
3239+
};
3240+
static const unsigned long long fast_perm_limits[] = {
3241+
0, ULLONG_MAX, 4294967296ULL, 2642246, 65537, 7133, 1627, 568, // 0-7
3242+
259, 142, 88, 61, 45, 36, 30, // 8-14
3243+
};
3244+
3245+
if (k == 0) {
3246+
return PyLong_FromLong(1);
3247+
}
3248+
3249+
/* For small enough n and k the result fits in the 64-bit range and can
3250+
* be calculated without allocating intermediate PyLong objects. */
3251+
if (iscomb
3252+
? (k < Py_ARRAY_LENGTH(fast_comb_limits)
3253+
&& n <= fast_comb_limits[k])
3254+
: (k < Py_ARRAY_LENGTH(fast_perm_limits)
3255+
&& n <= fast_perm_limits[k]))
3256+
{
3257+
unsigned long long result = n;
3258+
if (iscomb) {
3259+
for (unsigned long long i = 1; i < k;) {
3260+
result *= --n;
3261+
result /= ++i;
3262+
}
3263+
}
3264+
else {
3265+
for (unsigned long long i = 1; i < k;) {
3266+
result *= --n;
3267+
++i;
3268+
}
3269+
}
3270+
return PyLong_FromUnsignedLongLong(result);
3271+
}
3272+
3273+
/* For larger n use recursive formula. */
3274+
/* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
3275+
unsigned long long j = k / 2;
3276+
PyObject *a, *b;
3277+
a = perm_comb_small(n, j, iscomb);
3278+
if (a == NULL) {
3279+
return NULL;
3280+
}
3281+
b = perm_comb_small(n - j, k - j, iscomb);
3282+
if (b == NULL) {
3283+
goto error;
3284+
}
3285+
Py_SETREF(a, PyNumber_Multiply(a, b));
3286+
Py_DECREF(b);
3287+
if (iscomb && a != NULL) {
3288+
b = perm_comb_small(k, j, 1);
3289+
if (b == NULL) {
3290+
goto error;
3291+
}
3292+
Py_SETREF(a, PyNumber_FloorDivide(a, b));
3293+
Py_DECREF(b);
3294+
}
3295+
return a;
3296+
3297+
error:
3298+
Py_DECREF(a);
3299+
return NULL;
3300+
}
3301+
3302+
/* Calculate P(n, k) or C(n, k) using recursive formulas.
3303+
* It is more efficient than sequential multiplication thanks to
3304+
* Karatsuba multiplication.
3305+
*/
3306+
static PyObject *
3307+
perm_comb(PyObject *n, unsigned long long k, int iscomb)
3308+
{
3309+
if (k == 0) {
3310+
return PyLong_FromLong(1);
3311+
}
3312+
if (k == 1) {
3313+
Py_INCREF(n);
3314+
return n;
3315+
}
3316+
3317+
/* P(n, k) = P(n, j) * P(n-j, k-j) */
3318+
/* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
3319+
unsigned long long j = k / 2;
3320+
PyObject *a, *b;
3321+
a = perm_comb(n, j, iscomb);
3322+
if (a == NULL) {
3323+
return NULL;
3324+
}
3325+
PyObject *t = PyLong_FromUnsignedLongLong(j);
3326+
if (t == NULL) {
3327+
goto error;
3328+
}
3329+
n = PyNumber_Subtract(n, t);
3330+
Py_DECREF(t);
3331+
if (n == NULL) {
3332+
goto error;
3333+
}
3334+
b = perm_comb(n, k - j, iscomb);
3335+
Py_DECREF(n);
3336+
if (b == NULL) {
3337+
goto error;
3338+
}
3339+
Py_SETREF(a, PyNumber_Multiply(a, b));
3340+
Py_DECREF(b);
3341+
if (iscomb && a != NULL) {
3342+
b = perm_comb_small(k, j, 1);
3343+
if (b == NULL) {
3344+
goto error;
3345+
}
3346+
Py_SETREF(a, PyNumber_FloorDivide(a, b));
3347+
Py_DECREF(b);
3348+
}
3349+
return a;
3350+
3351+
error:
3352+
Py_DECREF(a);
3353+
return NULL;
3354+
}
3355+
32243356
/*[clinic input]
32253357
math.perm
32263358
@@ -3244,9 +3376,9 @@ static PyObject *
32443376
math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
32453377
/*[clinic end generated code: output=e021a25469653e23 input=5311c5a00f359b53]*/
32463378
{
3247-
PyObject *result = NULL, *factor = NULL;
3379+
PyObject *result = NULL;
32483380
int overflow, cmp;
3249-
long long i, factors;
3381+
long long ki, ni;
32503382

32513383
if (k == Py_None) {
32523384
return math_factorial(module, n);
@@ -3260,6 +3392,7 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
32603392
Py_DECREF(n);
32613393
return NULL;
32623394
}
3395+
assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));
32633396

32643397
if (Py_SIZE(n) < 0) {
32653398
PyErr_SetString(PyExc_ValueError,
@@ -3281,57 +3414,38 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
32813414
goto error;
32823415
}
32833416

3284-
factors = PyLong_AsLongLongAndOverflow(k, &overflow);
3417+
ki = PyLong_AsLongLongAndOverflow(k, &overflow);
3418+
assert(overflow >= 0 && !PyErr_Occurred());
32853419
if (overflow > 0) {
32863420
PyErr_Format(PyExc_OverflowError,
32873421
"k must not exceed %lld",
32883422
LLONG_MAX);
32893423
goto error;
32903424
}
3291-
else if (factors == -1) {
3292-
/* k is nonnegative, so a return value of -1 can only indicate error */
3293-
goto error;
3294-
}
3425+
assert(ki >= 0);
32953426

3296-
if (factors == 0) {
3297-
result = PyLong_FromLong(1);
3298-
goto done;
3427+
ni = PyLong_AsLongLongAndOverflow(n, &overflow);
3428+
assert(overflow >= 0 && !PyErr_Occurred());
3429+
if (!overflow && ki > 1) {
3430+
assert(ni >= 0);
3431+
result = perm_comb_small((unsigned long long)ni,
3432+
(unsigned long long)ki, 0);
32993433
}
3300-
3301-
result = n;
3302-
Py_INCREF(result);
3303-
if (factors == 1) {
3304-
goto done;
3305-
}
3306-
3307-
factor = Py_NewRef(n);
3308-
PyObject *one = _PyLong_GetOne(); // borrowed ref
3309-
for (i = 1; i < factors; ++i) {
3310-
Py_SETREF(factor, PyNumber_Subtract(factor, one));
3311-
if (factor == NULL) {
3312-
goto error;
3313-
}
3314-
Py_SETREF(result, PyNumber_Multiply(result, factor));
3315-
if (result == NULL) {
3316-
goto error;
3317-
}
3434+
else {
3435+
result = perm_comb(n, (unsigned long long)ki, 0);
33183436
}
3319-
Py_DECREF(factor);
33203437

33213438
done:
33223439
Py_DECREF(n);
33233440
Py_DECREF(k);
33243441
return result;
33253442

33263443
error:
3327-
Py_XDECREF(factor);
3328-
Py_XDECREF(result);
33293444
Py_DECREF(n);
33303445
Py_DECREF(k);
33313446
return NULL;
33323447
}
33333448

3334-
33353449
/*[clinic input]
33363450
math.comb
33373451
@@ -3357,9 +3471,9 @@ static PyObject *
33573471
math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
33583472
/*[clinic end generated code: output=bd2cec8d854f3493 input=9a05315af2518709]*/
33593473
{
3360-
PyObject *result = NULL, *factor = NULL, *temp;
3474+
PyObject *result = NULL, *temp;
33613475
int overflow, cmp;
3362-
long long i, factors;
3476+
long long ki, ni;
33633477

33643478
n = PyNumber_Index(n);
33653479
if (n == NULL) {
@@ -3370,6 +3484,7 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
33703484
Py_DECREF(n);
33713485
return NULL;
33723486
}
3487+
assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));
33733488

33743489
if (Py_SIZE(n) < 0) {
33753490
PyErr_SetString(PyExc_ValueError,
@@ -3382,82 +3497,66 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
33823497
goto error;
33833498
}
33843499

3385-
/* k = min(k, n - k) */
3386-
temp = PyNumber_Subtract(n, k);
3387-
if (temp == NULL) {
3388-
goto error;
3389-
}
3390-
if (Py_SIZE(temp) < 0) {
3391-
Py_DECREF(temp);
3392-
result = PyLong_FromLong(0);
3393-
goto done;
3394-
}
3395-
cmp = PyObject_RichCompareBool(temp, k, Py_LT);
3396-
if (cmp > 0) {
3397-
Py_SETREF(k, temp);
3500+
ni = PyLong_AsLongLongAndOverflow(n, &overflow);
3501+
assert(overflow >= 0 && !PyErr_Occurred());
3502+
if (!overflow) {
3503+
assert(ni >= 0);
3504+
ki = PyLong_AsLongLongAndOverflow(k, &overflow);
3505+
assert(overflow >= 0 && !PyErr_Occurred());
3506+
if (overflow || ki > ni) {
3507+
result = PyLong_FromLong(0);
3508+
goto done;
3509+
}
3510+
assert(ki >= 0);
3511+
ki = Py_MIN(ki, ni - ki);
3512+
if (ki > 1) {
3513+
result = perm_comb_small((unsigned long long)ni,
3514+
(unsigned long long)ki, 1);
3515+
goto done;
3516+
}
3517+
/* For k == 1 just return the original n in perm_comb(). */
33983518
}
33993519
else {
3400-
Py_DECREF(temp);
3401-
if (cmp < 0) {
3520+
/* k = min(k, n - k) */
3521+
temp = PyNumber_Subtract(n, k);
3522+
if (temp == NULL) {
34023523
goto error;
34033524
}
3404-
}
3405-
3406-
factors = PyLong_AsLongLongAndOverflow(k, &overflow);
3407-
if (overflow > 0) {
3408-
PyErr_Format(PyExc_OverflowError,
3409-
"min(n - k, k) must not exceed %lld",
3410-
LLONG_MAX);
3411-
goto error;
3412-
}
3413-
if (factors == -1) {
3414-
/* k is nonnegative, so a return value of -1 can only indicate error */
3415-
goto error;
3416-
}
3417-
3418-
if (factors == 0) {
3419-
result = PyLong_FromLong(1);
3420-
goto done;
3421-
}
3422-
3423-
result = n;
3424-
Py_INCREF(result);
3425-
if (factors == 1) {
3426-
goto done;
3427-
}
3428-
3429-
factor = Py_NewRef(n);
3430-
PyObject *one = _PyLong_GetOne(); // borrowed ref
3431-
for (i = 1; i < factors; ++i) {
3432-
Py_SETREF(factor, PyNumber_Subtract(factor, one));
3433-
if (factor == NULL) {
3434-
goto error;
3525+
if (Py_SIZE(temp) < 0) {
3526+
Py_DECREF(temp);
3527+
result = PyLong_FromLong(0);
3528+
goto done;
34353529
}
3436-
Py_SETREF(result, PyNumber_Multiply(result, factor));
3437-
if (result == NULL) {
3438-
goto error;
3530+
cmp = PyObject_RichCompareBool(temp, k, Py_LT);
3531+
if (cmp > 0) {
3532+
Py_SETREF(k, temp);
34393533
}
3440-
3441-
temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
3442-
if (temp == NULL) {
3443-
goto error;
3534+
else {
3535+
Py_DECREF(temp);
3536+
if (cmp < 0) {
3537+
goto error;
3538+
}
34443539
}
3445-
Py_SETREF(result, PyNumber_FloorDivide(result, temp));
3446-
Py_DECREF(temp);
3447-
if (result == NULL) {
3540+
3541+
ki = PyLong_AsLongLongAndOverflow(k, &overflow);
3542+
assert(overflow >= 0 && !PyErr_Occurred());
3543+
if (overflow) {
3544+
PyErr_Format(PyExc_OverflowError,
3545+
"min(n - k, k) must not exceed %lld",
3546+
LLONG_MAX);
34483547
goto error;
34493548
}
3549+
assert(ki >= 0);
34503550
}
3451-
Py_DECREF(factor);
3551+
3552+
result = perm_comb(n, (unsigned long long)ki, 1);
34523553

34533554
done:
34543555
Py_DECREF(n);
34553556
Py_DECREF(k);
34563557
return result;
34573558

34583559
error:
3459-
Py_XDECREF(factor);
3460-
Py_XDECREF(result);
34613560
Py_DECREF(n);
34623561
Py_DECREF(k);
34633562
return NULL;

0 commit comments

Comments
 (0)