Skip to content

Commit 6f67832

Browse files
committed
Make the RNG fall back to using an algorithm handle if BCryptGenRandom fails
Based on rust-lang/rust#102044
1 parent 7f73e3c commit 6f67832

File tree

1 file changed

+118
-19
lines changed

1 file changed

+118
-19
lines changed

src/windows.rs

Lines changed: 118 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,144 @@
55
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
8+
#![allow(non_camel_case_types)]
89

910
use crate::Error;
10-
use core::{ffi::c_void, mem::MaybeUninit, num::NonZeroU32, ptr};
11+
use core::{
12+
convert::TryInto,
13+
ffi::{c_long, c_void},
14+
mem::MaybeUninit,
15+
num::NonZeroU32,
16+
ptr,
17+
};
1118

19+
// same as Rust's libstd.
20+
type BCRYPT_ALG_HANDLE = *mut c_void;
21+
type NTSTATUS = c_long;
22+
23+
// "RNG\0"
24+
const BCRYPT_RNG_ALGORITHM: &[u16] = &[b'R' as u16, b'N' as u16, b'G' as u16, 0];
1225
const BCRYPT_USE_SYSTEM_PREFERRED_RNG: u32 = 0x00000002;
1326

27+
// Equivalent to the `NT_SUCCESS` C preprocessor macro.
28+
// See: https://docs.microsoft.com/en-us/windows-hardware/drivers/kernel/using-ntstatus-values
29+
fn nt_success(status: NTSTATUS) -> bool {
30+
status >= 0
31+
}
32+
33+
/// Extract error code and turn into an `Error`
34+
fn nt_error(status: NTSTATUS) -> Error {
35+
// We zeroize the highest bit, so the error code will reside
36+
// inside the range designated for OS codes.
37+
let code = status as u32 ^ (1 << 31);
38+
// SAFETY: the second highest bit is always equal to one,
39+
// so it's impossible to get zero. Unfortunately the type
40+
// system does not have a way to express this yet.
41+
let code = unsafe { NonZeroU32::new_unchecked(code) };
42+
Error::from(code)
43+
}
44+
1445
#[link(name = "bcrypt")]
1546
extern "system" {
1647
fn BCryptGenRandom(
17-
hAlgorithm: *mut c_void,
48+
hAlgorithm: BCRYPT_ALG_HANDLE,
1849
pBuffer: *mut u8,
1950
cbBuffer: u32,
2051
dwFlags: u32,
21-
) -> u32;
52+
) -> NTSTATUS;
53+
pub fn BCryptOpenAlgorithmProvider(
54+
phalgorithm: *mut BCRYPT_ALG_HANDLE,
55+
pszAlgId: *const u16,
56+
pszimplementation: *const u16,
57+
dwflags: u32,
58+
) -> NTSTATUS;
59+
pub fn BCryptCloseAlgorithmProvider(hAlgorithm: BCRYPT_ALG_HANDLE, dwFlags: u32) -> NTSTATUS;
2260
}
2361

2462
pub fn getrandom_inner(dest: &mut [MaybeUninit<u8>]) -> Result<(), Error> {
2563
// Prevent overflow of u32
2664
for chunk in dest.chunks_mut(u32::max_value() as usize) {
27-
// BCryptGenRandom was introduced in Windows Vista
65+
if let Err(_) = Rng::SYSTEM.random(chunk) {
66+
fallback_rng(chunk)?;
67+
}
68+
}
69+
Ok(())
70+
}
71+
72+
struct Rng {
73+
algorithm: BCRYPT_ALG_HANDLE,
74+
flags: u32,
75+
}
76+
77+
impl Rng {
78+
const SYSTEM: Self = unsafe { Self::new(ptr::null_mut(), BCRYPT_USE_SYSTEM_PREFERRED_RNG) };
79+
80+
/// Create the RNG from an existing algorithm handle.
81+
///
82+
/// # Safety
83+
///
84+
/// The handle must either be null or a valid algorithm handle.
85+
const unsafe fn new(algorithm: BCRYPT_ALG_HANDLE, flags: u32) -> Self {
86+
Self { algorithm, flags }
87+
}
88+
89+
/// Open a handle to the RNG algorithm.
90+
fn open() -> Result<Self, Error> {
91+
use core::sync::atomic::AtomicPtr;
92+
use core::sync::atomic::Ordering::{Acquire, Release};
93+
94+
// An atomic is used so we don't need to reopen the handle every time.
95+
static HANDLE: AtomicPtr<c_void> = AtomicPtr::new(ptr::null_mut());
96+
97+
let mut handle = HANDLE.load(Acquire);
98+
if handle.is_null() {
99+
let status = unsafe {
100+
BCryptOpenAlgorithmProvider(
101+
&mut handle,
102+
BCRYPT_RNG_ALGORITHM.as_ptr(),
103+
ptr::null(),
104+
0,
105+
)
106+
};
107+
if nt_success(status) {
108+
// If another thread opens a handle first then use that handle instead.
109+
let result = HANDLE.compare_exchange(ptr::null_mut(), handle, Release, Acquire);
110+
if let Err(previous_handle) = result {
111+
// Close our handle and return the previous one.
112+
unsafe { BCryptCloseAlgorithmProvider(handle, 0) };
113+
handle = previous_handle;
114+
}
115+
Ok(unsafe { Self::new(handle, 0) })
116+
} else {
117+
Err(nt_error(status))
118+
}
119+
} else {
120+
Ok(unsafe { Self::new(handle, 0) })
121+
}
122+
}
123+
124+
fn random(&self, dest: &mut [MaybeUninit<u8>]) -> Result<(), Error> {
125+
let len: u32 = dest.len().try_into().unwrap();
126+
// SAFETY: dest is valid, writable buffer of length len
28127
let ret = unsafe {
29128
BCryptGenRandom(
30-
ptr::null_mut(),
31-
chunk.as_mut_ptr() as *mut u8,
32-
chunk.len() as u32,
33-
BCRYPT_USE_SYSTEM_PREFERRED_RNG,
129+
self.algorithm,
130+
dest.as_mut_ptr() as *mut u8,
131+
len,
132+
self.flags,
34133
)
35134
};
36-
// NTSTATUS codes use the two highest bits for severity status.
37-
if ret >> 30 == 0b11 {
38-
// We zeroize the highest bit, so the error code will reside
39-
// inside the range designated for OS codes.
40-
let code = ret ^ (1 << 31);
41-
// SAFETY: the second highest bit is always equal to one,
42-
// so it's impossible to get zero. Unfortunately the type
43-
// system does not have a way to express this yet.
44-
let code = unsafe { NonZeroU32::new_unchecked(code) };
45-
return Err(Error::from(code));
135+
136+
if nt_success(ret) {
137+
return Ok(());
46138
}
139+
140+
Err(nt_error(ret))
47141
}
48-
Ok(())
142+
}
143+
144+
/// Generate random numbers using the fallback RNG function
145+
#[inline(never)]
146+
fn fallback_rng(dest: &mut [MaybeUninit<u8>]) -> Result<(), Error> {
147+
Rng::open()?.random(dest)
49148
}

0 commit comments

Comments
 (0)