Skip to content
This repository was archived by the owner on Apr 20, 2025. It is now read-only.

Commit 88418f0

Browse files
committed
Fix threading issue introduced in 4.7
Computing the blinding factor and its inverse was done in a thread-unsafe manner. Locking the computation & update of the blinding factors, and passing these around in frame- and stack-bound data, solves this. This fixes part of the issues reported in #173, but there is more going on in that particular report.
1 parent 3af4e65 commit 88418f0

File tree

3 files changed

+55
-38
lines changed

3 files changed

+55
-38
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Python-RSA changelog
22

3+
## Version 4.7.1 - in development
4+
5+
- Fix threading issue introduced in 4.7 ([#173](https://github.com/sybrenstuvel/python-rsa/issues/173)
6+
37
## Version 4.7 - released 2021-01-10
48

59
- Fix [#165](https://github.com/sybrenstuvel/python-rsa/issues/165):

rsa/key.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"""
3333

3434
import logging
35+
import threading
3536
import typing
3637
import warnings
3738

@@ -49,7 +50,7 @@
4950
class AbstractKey:
5051
"""Abstract superclass for private and public keys."""
5152

52-
__slots__ = ('n', 'e', 'blindfac', 'blindfac_inverse')
53+
__slots__ = ('n', 'e', 'blindfac', 'blindfac_inverse', 'mutex')
5354

5455
def __init__(self, n: int, e: int) -> None:
5556
self.n = n
@@ -58,6 +59,10 @@ def __init__(self, n: int, e: int) -> None:
5859
# These will be computed properly on the first call to blind().
5960
self.blindfac = self.blindfac_inverse = -1
6061

62+
# Used to protect updates to the blinding factor in multi-threaded
63+
# environments.
64+
self.mutex = threading.Lock()
65+
6166
@classmethod
6267
def _load_pkcs1_pem(cls, keyfile: bytes) -> 'AbstractKey':
6368
"""Loads a key in PKCS#1 PEM format, implement in a subclass.
@@ -148,36 +153,33 @@ def save_pkcs1(self, format: str = 'PEM') -> bytes:
148153
method = self._assert_format_exists(format, methods)
149154
return method()
150155

151-
def blind(self, message: int) -> int:
152-
"""Performs blinding on the message using random number 'r'.
156+
def blind(self, message: int) -> typing.Tuple[int, int]:
157+
"""Performs blinding on the message.
153158
154159
:param message: the message, as integer, to blind.
155-
:type message: int
156160
:param r: the random number to blind with.
157-
:type r: int
158-
:return: the blinded message.
159-
:rtype: int
161+
:return: tuple (the blinded message, the inverse of the used blinding factor)
160162
161163
The blinding is such that message = unblind(decrypt(blind(encrypt(message))).
162164
163165
See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
164166
"""
165-
self._update_blinding_factor()
166-
return (message * pow(self.blindfac, self.e, self.n)) % self.n
167+
blindfac, blindfac_inverse = self._update_blinding_factor()
168+
blinded = (message * pow(blindfac, self.e, self.n)) % self.n
169+
return blinded, blindfac_inverse
167170

168-
def unblind(self, blinded: int) -> int:
169-
"""Performs blinding on the message using random number 'r'.
171+
def unblind(self, blinded: int, blindfac_inverse: int) -> int:
172+
"""Performs blinding on the message using random number 'blindfac_inverse'.
170173
171174
:param blinded: the blinded message, as integer, to unblind.
172-
:param r: the random number to unblind with.
175+
:param blindfac: the factor to unblind with.
173176
:return: the original message.
174177
175178
The blinding is such that message = unblind(decrypt(blind(encrypt(message))).
176179
177180
See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
178181
"""
179-
180-
return (self.blindfac_inverse * blinded) % self.n
182+
return (blindfac_inverse * blinded) % self.n
181183

182184
def _initial_blinding_factor(self) -> int:
183185
for _ in range(1000):
@@ -186,18 +188,29 @@ def _initial_blinding_factor(self) -> int:
186188
return blind_r
187189
raise RuntimeError('unable to find blinding factor')
188190

189-
def _update_blinding_factor(self):
190-
if self.blindfac < 0:
191-
# Compute initial blinding factor, which is rather slow to do.
192-
self.blindfac = self._initial_blinding_factor()
193-
self.blindfac_inverse = rsa.common.inverse(self.blindfac, self.n)
194-
else:
195-
# Reuse previous blinding factor as per section 9 of 'A Timing
196-
# Attack against RSA with the Chinese Remainder Theorem' by Werner
197-
# Schindler.
198-
# See https://tls.mbed.org/public/WSchindler-RSA_Timing_Attack.pdf
199-
self.blindfac = pow(self.blindfac, 2, self.n)
200-
self.blindfac_inverse = pow(self.blindfac_inverse, 2, self.n)
191+
def _update_blinding_factor(self) -> typing.Tuple[int, int]:
192+
"""Update blinding factors.
193+
194+
Computing a blinding factor is expensive, so instead this function
195+
does this once, then updates the blinding factor as per section 9
196+
of 'A Timing Attack against RSA with the Chinese Remainder Theorem'
197+
by Werner Schindler.
198+
See https://tls.mbed.org/public/WSchindler-RSA_Timing_Attack.pdf
199+
200+
:return: the new blinding factor and its inverse.
201+
"""
202+
203+
with self.mutex:
204+
if self.blindfac < 0:
205+
# Compute initial blinding factor, which is rather slow to do.
206+
self.blindfac = self._initial_blinding_factor()
207+
self.blindfac_inverse = rsa.common.inverse(self.blindfac, self.n)
208+
else:
209+
# Reuse previous blinding factor.
210+
self.blindfac = pow(self.blindfac, 2, self.n)
211+
self.blindfac_inverse = pow(self.blindfac_inverse, 2, self.n)
212+
213+
return self.blindfac, self.blindfac_inverse
201214

202215
class PublicKey(AbstractKey):
203216
"""Represents a public RSA key.
@@ -446,9 +459,10 @@ def blinded_decrypt(self, encrypted: int) -> int:
446459
:rtype: int
447460
"""
448461

449-
blinded = self.blind(encrypted) # blind before decrypting
462+
# Blinding and un-blinding should be using the same factor
463+
blinded, blindfac_inverse = self.blind(encrypted)
450464
decrypted = rsa.core.decrypt_int(blinded, self.d, self.n)
451-
return self.unblind(decrypted)
465+
return self.unblind(decrypted, blindfac_inverse)
452466

453467
def blinded_encrypt(self, message: int) -> int:
454468
"""Encrypts the message using blinding to prevent side-channel attacks.
@@ -460,9 +474,9 @@ def blinded_encrypt(self, message: int) -> int:
460474
:rtype: int
461475
"""
462476

463-
blinded = self.blind(message) # blind before encrypting
477+
blinded, blindfac_inverse = self.blind(message)
464478
encrypted = rsa.core.encrypt_int(blinded, self.d, self.n)
465-
return self.unblind(encrypted)
479+
return self.unblind(encrypted, blindfac_inverse)
466480

467481
@classmethod
468482
def _load_pkcs1_der(cls, keyfile: bytes) -> 'PrivateKey':

tests/test_key.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,19 @@ def test_blinding(self):
2121
message = 12345
2222
encrypted = rsa.core.encrypt_int(message, pk.e, pk.n)
2323

24-
blinded_1 = pk.blind(encrypted) # blind before decrypting
24+
blinded_1, unblind_1 = pk.blind(encrypted) # blind before decrypting
2525
decrypted = rsa.core.decrypt_int(blinded_1, pk.d, pk.n)
26-
unblinded_1 = pk.unblind(decrypted)
26+
unblinded_1 = pk.unblind(decrypted, unblind_1)
2727

2828
self.assertEqual(unblinded_1, message)
2929

3030
# Re-blinding should use a different blinding factor.
31-
blinded_2 = pk.blind(encrypted) # blind before decrypting
31+
blinded_2, unblind_2 = pk.blind(encrypted) # blind before decrypting
3232
self.assertNotEqual(blinded_1, blinded_2)
3333

3434
# The unblinding should still work, though.
3535
decrypted = rsa.core.decrypt_int(blinded_2, pk.d, pk.n)
36-
unblinded_2 = pk.unblind(decrypted)
36+
unblinded_2 = pk.unblind(decrypted, unblind_2)
3737
self.assertEqual(unblinded_2, message)
3838

3939

@@ -69,10 +69,9 @@ def getprime(_):
6969
# This exponent will cause two other primes to be generated.
7070
exponent = 136407
7171

72-
(p, q, e, d) = rsa.key.gen_keys(64,
73-
accurate=False,
74-
getprime_func=getprime,
75-
exponent=exponent)
72+
(p, q, e, d) = rsa.key.gen_keys(
73+
64, accurate=False, getprime_func=getprime, exponent=exponent
74+
)
7675
self.assertEqual(39317, p)
7776
self.assertEqual(33107, q)
7877

0 commit comments

Comments
 (0)