@@ -3221,6 +3221,138 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
3221
3221
}
3222
3222
3223
3223
3224
+ /* Number of permutations and combinations.
3225
+ * P(n, k) = n! / (n-k)!
3226
+ * C(n, k) = P(n, k) / k!
3227
+ */
3228
+
3229
+ /* Calculate C(n, k) for n in the 63-bit range. */
3230
+ static PyObject *
3231
+ perm_comb_small (unsigned long long n , unsigned long long k , int iscomb )
3232
+ {
3233
+ /* long long is at least 64 bit */
3234
+ static const unsigned long long fast_comb_limits [] = {
3235
+ 0 , ULLONG_MAX , 4294967296ULL , 3329022 , 102570 , 13467 , 3612 , 1449 , // 0-7
3236
+ 746 , 453 , 308 , 227 , 178 , 147 , 125 , 110 , // 8-15
3237
+ 99 , 90 , 84 , 79 , 75 , 72 , 69 , 68 , // 16-23
3238
+ 66 , 65 , 64 , 63 , 63 , 62 , 62 , 62 , // 24-31
3239
+ };
3240
+ static const unsigned long long fast_perm_limits [] = {
3241
+ 0 , ULLONG_MAX , 4294967296ULL , 2642246 , 65537 , 7133 , 1627 , 568 , // 0-7
3242
+ 259 , 142 , 88 , 61 , 45 , 36 , 30 , // 8-14
3243
+ };
3244
+
3245
+ if (k == 0 ) {
3246
+ return PyLong_FromLong (1 );
3247
+ }
3248
+
3249
+ /* For small enough n and k the result fits in the 64-bit range and can
3250
+ * be calculated without allocating intermediate PyLong objects. */
3251
+ if (iscomb
3252
+ ? (k < Py_ARRAY_LENGTH (fast_comb_limits )
3253
+ && n <= fast_comb_limits [k ])
3254
+ : (k < Py_ARRAY_LENGTH (fast_perm_limits )
3255
+ && n <= fast_perm_limits [k ]))
3256
+ {
3257
+ unsigned long long result = n ;
3258
+ if (iscomb ) {
3259
+ for (unsigned long long i = 1 ; i < k ;) {
3260
+ result *= -- n ;
3261
+ result /= ++ i ;
3262
+ }
3263
+ }
3264
+ else {
3265
+ for (unsigned long long i = 1 ; i < k ;) {
3266
+ result *= -- n ;
3267
+ ++ i ;
3268
+ }
3269
+ }
3270
+ return PyLong_FromUnsignedLongLong (result );
3271
+ }
3272
+
3273
+ /* For larger n use recursive formula. */
3274
+ /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
3275
+ unsigned long long j = k / 2 ;
3276
+ PyObject * a , * b ;
3277
+ a = perm_comb_small (n , j , iscomb );
3278
+ if (a == NULL ) {
3279
+ return NULL ;
3280
+ }
3281
+ b = perm_comb_small (n - j , k - j , iscomb );
3282
+ if (b == NULL ) {
3283
+ goto error ;
3284
+ }
3285
+ Py_SETREF (a , PyNumber_Multiply (a , b ));
3286
+ Py_DECREF (b );
3287
+ if (iscomb && a != NULL ) {
3288
+ b = perm_comb_small (k , j , 1 );
3289
+ if (b == NULL ) {
3290
+ goto error ;
3291
+ }
3292
+ Py_SETREF (a , PyNumber_FloorDivide (a , b ));
3293
+ Py_DECREF (b );
3294
+ }
3295
+ return a ;
3296
+
3297
+ error :
3298
+ Py_DECREF (a );
3299
+ return NULL ;
3300
+ }
3301
+
3302
+ /* Calculate P(n, k) or C(n, k) using recursive formulas.
3303
+ * It is more efficient than sequential multiplication thanks to
3304
+ * Karatsuba multiplication.
3305
+ */
3306
+ static PyObject *
3307
+ perm_comb (PyObject * n , unsigned long long k , int iscomb )
3308
+ {
3309
+ if (k == 0 ) {
3310
+ return PyLong_FromLong (1 );
3311
+ }
3312
+ if (k == 1 ) {
3313
+ Py_INCREF (n );
3314
+ return n ;
3315
+ }
3316
+
3317
+ /* P(n, k) = P(n, j) * P(n-j, k-j) */
3318
+ /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
3319
+ unsigned long long j = k / 2 ;
3320
+ PyObject * a , * b ;
3321
+ a = perm_comb (n , j , iscomb );
3322
+ if (a == NULL ) {
3323
+ return NULL ;
3324
+ }
3325
+ PyObject * t = PyLong_FromUnsignedLongLong (j );
3326
+ if (t == NULL ) {
3327
+ goto error ;
3328
+ }
3329
+ n = PyNumber_Subtract (n , t );
3330
+ Py_DECREF (t );
3331
+ if (n == NULL ) {
3332
+ goto error ;
3333
+ }
3334
+ b = perm_comb (n , k - j , iscomb );
3335
+ Py_DECREF (n );
3336
+ if (b == NULL ) {
3337
+ goto error ;
3338
+ }
3339
+ Py_SETREF (a , PyNumber_Multiply (a , b ));
3340
+ Py_DECREF (b );
3341
+ if (iscomb && a != NULL ) {
3342
+ b = perm_comb_small (k , j , 1 );
3343
+ if (b == NULL ) {
3344
+ goto error ;
3345
+ }
3346
+ Py_SETREF (a , PyNumber_FloorDivide (a , b ));
3347
+ Py_DECREF (b );
3348
+ }
3349
+ return a ;
3350
+
3351
+ error :
3352
+ Py_DECREF (a );
3353
+ return NULL ;
3354
+ }
3355
+
3224
3356
/*[clinic input]
3225
3357
math.perm
3226
3358
@@ -3244,9 +3376,9 @@ static PyObject *
3244
3376
math_perm_impl (PyObject * module , PyObject * n , PyObject * k )
3245
3377
/*[clinic end generated code: output=e021a25469653e23 input=5311c5a00f359b53]*/
3246
3378
{
3247
- PyObject * result = NULL , * factor = NULL ;
3379
+ PyObject * result = NULL ;
3248
3380
int overflow , cmp ;
3249
- long long i , factors ;
3381
+ long long ki , ni ;
3250
3382
3251
3383
if (k == Py_None ) {
3252
3384
return math_factorial (module , n );
@@ -3260,6 +3392,7 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
3260
3392
Py_DECREF (n );
3261
3393
return NULL ;
3262
3394
}
3395
+ assert (PyLong_CheckExact (n ) && PyLong_CheckExact (k ));
3263
3396
3264
3397
if (Py_SIZE (n ) < 0 ) {
3265
3398
PyErr_SetString (PyExc_ValueError ,
@@ -3281,57 +3414,38 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
3281
3414
goto error ;
3282
3415
}
3283
3416
3284
- factors = PyLong_AsLongLongAndOverflow (k , & overflow );
3417
+ ki = PyLong_AsLongLongAndOverflow (k , & overflow );
3418
+ assert (overflow >= 0 && !PyErr_Occurred ());
3285
3419
if (overflow > 0 ) {
3286
3420
PyErr_Format (PyExc_OverflowError ,
3287
3421
"k must not exceed %lld" ,
3288
3422
LLONG_MAX );
3289
3423
goto error ;
3290
3424
}
3291
- else if (factors == -1 ) {
3292
- /* k is nonnegative, so a return value of -1 can only indicate error */
3293
- goto error ;
3294
- }
3425
+ assert (ki >= 0 );
3295
3426
3296
- if (factors == 0 ) {
3297
- result = PyLong_FromLong (1 );
3298
- goto done ;
3427
+ ni = PyLong_AsLongLongAndOverflow (n , & overflow );
3428
+ assert (overflow >= 0 && !PyErr_Occurred ());
3429
+ if (!overflow && ki > 1 ) {
3430
+ assert (ni >= 0 );
3431
+ result = perm_comb_small ((unsigned long long )ni ,
3432
+ (unsigned long long )ki , 0 );
3299
3433
}
3300
-
3301
- result = n ;
3302
- Py_INCREF (result );
3303
- if (factors == 1 ) {
3304
- goto done ;
3305
- }
3306
-
3307
- factor = Py_NewRef (n );
3308
- PyObject * one = _PyLong_GetOne (); // borrowed ref
3309
- for (i = 1 ; i < factors ; ++ i ) {
3310
- Py_SETREF (factor , PyNumber_Subtract (factor , one ));
3311
- if (factor == NULL ) {
3312
- goto error ;
3313
- }
3314
- Py_SETREF (result , PyNumber_Multiply (result , factor ));
3315
- if (result == NULL ) {
3316
- goto error ;
3317
- }
3434
+ else {
3435
+ result = perm_comb (n , (unsigned long long )ki , 0 );
3318
3436
}
3319
- Py_DECREF (factor );
3320
3437
3321
3438
done :
3322
3439
Py_DECREF (n );
3323
3440
Py_DECREF (k );
3324
3441
return result ;
3325
3442
3326
3443
error :
3327
- Py_XDECREF (factor );
3328
- Py_XDECREF (result );
3329
3444
Py_DECREF (n );
3330
3445
Py_DECREF (k );
3331
3446
return NULL ;
3332
3447
}
3333
3448
3334
-
3335
3449
/*[clinic input]
3336
3450
math.comb
3337
3451
@@ -3357,9 +3471,9 @@ static PyObject *
3357
3471
math_comb_impl (PyObject * module , PyObject * n , PyObject * k )
3358
3472
/*[clinic end generated code: output=bd2cec8d854f3493 input=9a05315af2518709]*/
3359
3473
{
3360
- PyObject * result = NULL , * factor = NULL , * temp ;
3474
+ PyObject * result = NULL , * temp ;
3361
3475
int overflow , cmp ;
3362
- long long i , factors ;
3476
+ long long ki , ni ;
3363
3477
3364
3478
n = PyNumber_Index (n );
3365
3479
if (n == NULL ) {
@@ -3370,6 +3484,7 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
3370
3484
Py_DECREF (n );
3371
3485
return NULL ;
3372
3486
}
3487
+ assert (PyLong_CheckExact (n ) && PyLong_CheckExact (k ));
3373
3488
3374
3489
if (Py_SIZE (n ) < 0 ) {
3375
3490
PyErr_SetString (PyExc_ValueError ,
@@ -3382,82 +3497,66 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
3382
3497
goto error ;
3383
3498
}
3384
3499
3385
- /* k = min(k, n - k) */
3386
- temp = PyNumber_Subtract (n , k );
3387
- if (temp == NULL ) {
3388
- goto error ;
3389
- }
3390
- if (Py_SIZE (temp ) < 0 ) {
3391
- Py_DECREF (temp );
3392
- result = PyLong_FromLong (0 );
3393
- goto done ;
3394
- }
3395
- cmp = PyObject_RichCompareBool (temp , k , Py_LT );
3396
- if (cmp > 0 ) {
3397
- Py_SETREF (k , temp );
3500
+ ni = PyLong_AsLongLongAndOverflow (n , & overflow );
3501
+ assert (overflow >= 0 && !PyErr_Occurred ());
3502
+ if (!overflow ) {
3503
+ assert (ni >= 0 );
3504
+ ki = PyLong_AsLongLongAndOverflow (k , & overflow );
3505
+ assert (overflow >= 0 && !PyErr_Occurred ());
3506
+ if (overflow || ki > ni ) {
3507
+ result = PyLong_FromLong (0 );
3508
+ goto done ;
3509
+ }
3510
+ assert (ki >= 0 );
3511
+ ki = Py_MIN (ki , ni - ki );
3512
+ if (ki > 1 ) {
3513
+ result = perm_comb_small ((unsigned long long )ni ,
3514
+ (unsigned long long )ki , 1 );
3515
+ goto done ;
3516
+ }
3517
+ /* For k == 1 just return the original n in perm_comb(). */
3398
3518
}
3399
3519
else {
3400
- Py_DECREF (temp );
3401
- if (cmp < 0 ) {
3520
+ /* k = min(k, n - k) */
3521
+ temp = PyNumber_Subtract (n , k );
3522
+ if (temp == NULL ) {
3402
3523
goto error ;
3403
3524
}
3404
- }
3405
-
3406
- factors = PyLong_AsLongLongAndOverflow (k , & overflow );
3407
- if (overflow > 0 ) {
3408
- PyErr_Format (PyExc_OverflowError ,
3409
- "min(n - k, k) must not exceed %lld" ,
3410
- LLONG_MAX );
3411
- goto error ;
3412
- }
3413
- if (factors == -1 ) {
3414
- /* k is nonnegative, so a return value of -1 can only indicate error */
3415
- goto error ;
3416
- }
3417
-
3418
- if (factors == 0 ) {
3419
- result = PyLong_FromLong (1 );
3420
- goto done ;
3421
- }
3422
-
3423
- result = n ;
3424
- Py_INCREF (result );
3425
- if (factors == 1 ) {
3426
- goto done ;
3427
- }
3428
-
3429
- factor = Py_NewRef (n );
3430
- PyObject * one = _PyLong_GetOne (); // borrowed ref
3431
- for (i = 1 ; i < factors ; ++ i ) {
3432
- Py_SETREF (factor , PyNumber_Subtract (factor , one ));
3433
- if (factor == NULL ) {
3434
- goto error ;
3525
+ if (Py_SIZE (temp ) < 0 ) {
3526
+ Py_DECREF (temp );
3527
+ result = PyLong_FromLong (0 );
3528
+ goto done ;
3435
3529
}
3436
- Py_SETREF ( result , PyNumber_Multiply ( result , factor ) );
3437
- if (result == NULL ) {
3438
- goto error ;
3530
+ cmp = PyObject_RichCompareBool ( temp , k , Py_LT );
3531
+ if (cmp > 0 ) {
3532
+ Py_SETREF ( k , temp ) ;
3439
3533
}
3440
-
3441
- temp = PyLong_FromUnsignedLongLong ((unsigned long long )i + 1 );
3442
- if (temp == NULL ) {
3443
- goto error ;
3534
+ else {
3535
+ Py_DECREF (temp );
3536
+ if (cmp < 0 ) {
3537
+ goto error ;
3538
+ }
3444
3539
}
3445
- Py_SETREF (result , PyNumber_FloorDivide (result , temp ));
3446
- Py_DECREF (temp );
3447
- if (result == NULL ) {
3540
+
3541
+ ki = PyLong_AsLongLongAndOverflow (k , & overflow );
3542
+ assert (overflow >= 0 && !PyErr_Occurred ());
3543
+ if (overflow ) {
3544
+ PyErr_Format (PyExc_OverflowError ,
3545
+ "min(n - k, k) must not exceed %lld" ,
3546
+ LLONG_MAX );
3448
3547
goto error ;
3449
3548
}
3549
+ assert (ki >= 0 );
3450
3550
}
3451
- Py_DECREF (factor );
3551
+
3552
+ result = perm_comb (n , (unsigned long long )ki , 1 );
3452
3553
3453
3554
done :
3454
3555
Py_DECREF (n );
3455
3556
Py_DECREF (k );
3456
3557
return result ;
3457
3558
3458
3559
error :
3459
- Py_XDECREF (factor );
3460
- Py_XDECREF (result );
3461
3560
Py_DECREF (n );
3462
3561
Py_DECREF (k );
3463
3562
return NULL ;
0 commit comments