|
5 | 5 | // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
|
6 | 6 | // option. This file may not be copied, modified, or distributed
|
7 | 7 | // except according to those terms.
|
| 8 | +#![allow(non_camel_case_types)] |
8 | 9 |
|
9 | 10 | use crate::Error;
|
10 |
| -use core::{ffi::c_void, mem::MaybeUninit, num::NonZeroU32, ptr}; |
| 11 | +use core::{ |
| 12 | + convert::TryInto, |
| 13 | + ffi::{c_long, c_void}, |
| 14 | + mem::MaybeUninit, |
| 15 | + num::NonZeroU32, |
| 16 | + ptr, |
| 17 | +}; |
11 | 18 |
|
| 19 | +// same as Rust's libstd. |
| 20 | +type BCRYPT_ALG_HANDLE = *mut c_void; |
| 21 | +type NTSTATUS = c_long; |
| 22 | + |
| 23 | +// "RNG\0" |
| 24 | +const BCRYPT_RNG_ALGORITHM: &[u16] = &[b'R' as u16, b'N' as u16, b'G' as u16, 0]; |
12 | 25 | const BCRYPT_USE_SYSTEM_PREFERRED_RNG: u32 = 0x00000002;
|
13 | 26 |
|
| 27 | +// Equivalent to the `NT_SUCCESS` C preprocessor macro. |
| 28 | +// See: https://docs.microsoft.com/en-us/windows-hardware/drivers/kernel/using-ntstatus-values |
| 29 | +fn nt_success(status: NTSTATUS) -> bool { |
| 30 | + status >= 0 |
| 31 | +} |
| 32 | + |
| 33 | +/// Extract error code and turn into an `Error` |
| 34 | +fn nt_error(status: NTSTATUS) -> Error { |
| 35 | + // We zeroize the highest bit, so the error code will reside |
| 36 | + // inside the range designated for OS codes. |
| 37 | + let code = status as u32 ^ (1 << 31); |
| 38 | + // SAFETY: the second highest bit is always equal to one, |
| 39 | + // so it's impossible to get zero. Unfortunately the type |
| 40 | + // system does not have a way to express this yet. |
| 41 | + let code = unsafe { NonZeroU32::new_unchecked(code) }; |
| 42 | + Error::from(code) |
| 43 | +} |
| 44 | + |
14 | 45 | #[link(name = "bcrypt")]
|
15 | 46 | extern "system" {
|
16 | 47 | fn BCryptGenRandom(
|
17 |
| - hAlgorithm: *mut c_void, |
| 48 | + hAlgorithm: BCRYPT_ALG_HANDLE, |
18 | 49 | pBuffer: *mut u8,
|
19 | 50 | cbBuffer: u32,
|
20 | 51 | dwFlags: u32,
|
21 |
| - ) -> u32; |
| 52 | + ) -> NTSTATUS; |
| 53 | + pub fn BCryptOpenAlgorithmProvider( |
| 54 | + phalgorithm: *mut BCRYPT_ALG_HANDLE, |
| 55 | + pszAlgId: *const u16, |
| 56 | + pszimplementation: *const u16, |
| 57 | + dwflags: u32, |
| 58 | + ) -> NTSTATUS; |
| 59 | + pub fn BCryptCloseAlgorithmProvider(hAlgorithm: BCRYPT_ALG_HANDLE, dwFlags: u32) -> NTSTATUS; |
22 | 60 | }
|
23 | 61 |
|
24 | 62 | pub fn getrandom_inner(dest: &mut [MaybeUninit<u8>]) -> Result<(), Error> {
|
25 | 63 | // Prevent overflow of u32
|
26 | 64 | for chunk in dest.chunks_mut(u32::max_value() as usize) {
|
27 |
| - // BCryptGenRandom was introduced in Windows Vista |
| 65 | + if let Err(_) = Rng::SYSTEM.random(chunk) { |
| 66 | + fallback_rng(chunk)?; |
| 67 | + } |
| 68 | + } |
| 69 | + Ok(()) |
| 70 | +} |
| 71 | + |
| 72 | +struct Rng { |
| 73 | + algorithm: BCRYPT_ALG_HANDLE, |
| 74 | + flags: u32, |
| 75 | +} |
| 76 | + |
| 77 | +impl Rng { |
| 78 | + const SYSTEM: Self = unsafe { Self::new(ptr::null_mut(), BCRYPT_USE_SYSTEM_PREFERRED_RNG) }; |
| 79 | + |
| 80 | + /// Create the RNG from an existing algorithm handle. |
| 81 | + /// |
| 82 | + /// # Safety |
| 83 | + /// |
| 84 | + /// The handle must either be null or a valid algorithm handle. |
| 85 | + const unsafe fn new(algorithm: BCRYPT_ALG_HANDLE, flags: u32) -> Self { |
| 86 | + Self { algorithm, flags } |
| 87 | + } |
| 88 | + |
| 89 | + /// Open a handle to the RNG algorithm. |
| 90 | + fn open() -> Result<Self, Error> { |
| 91 | + use core::sync::atomic::AtomicPtr; |
| 92 | + use core::sync::atomic::Ordering::{Acquire, Release}; |
| 93 | + |
| 94 | + // An atomic is used so we don't need to reopen the handle every time. |
| 95 | + static HANDLE: AtomicPtr<c_void> = AtomicPtr::new(ptr::null_mut()); |
| 96 | + |
| 97 | + let mut handle = HANDLE.load(Acquire); |
| 98 | + if handle.is_null() { |
| 99 | + let status = unsafe { |
| 100 | + BCryptOpenAlgorithmProvider( |
| 101 | + &mut handle, |
| 102 | + BCRYPT_RNG_ALGORITHM.as_ptr(), |
| 103 | + ptr::null(), |
| 104 | + 0, |
| 105 | + ) |
| 106 | + }; |
| 107 | + if nt_success(status) { |
| 108 | + // If another thread opens a handle first then use that handle instead. |
| 109 | + let result = HANDLE.compare_exchange(ptr::null_mut(), handle, Release, Acquire); |
| 110 | + if let Err(previous_handle) = result { |
| 111 | + // Close our handle and return the previous one. |
| 112 | + unsafe { BCryptCloseAlgorithmProvider(handle, 0) }; |
| 113 | + handle = previous_handle; |
| 114 | + } |
| 115 | + Ok(unsafe { Self::new(handle, 0) }) |
| 116 | + } else { |
| 117 | + Err(nt_error(status)) |
| 118 | + } |
| 119 | + } else { |
| 120 | + Ok(unsafe { Self::new(handle, 0) }) |
| 121 | + } |
| 122 | + } |
| 123 | + |
| 124 | + fn random(&self, dest: &mut [MaybeUninit<u8>]) -> Result<(), Error> { |
| 125 | + let len: u32 = dest.len().try_into().unwrap(); |
| 126 | + // SAFETY: dest is valid, writable buffer of length len |
28 | 127 | let ret = unsafe {
|
29 | 128 | BCryptGenRandom(
|
30 |
| - ptr::null_mut(), |
31 |
| - chunk.as_mut_ptr() as *mut u8, |
32 |
| - chunk.len() as u32, |
33 |
| - BCRYPT_USE_SYSTEM_PREFERRED_RNG, |
| 129 | + self.algorithm, |
| 130 | + dest.as_mut_ptr() as *mut u8, |
| 131 | + len, |
| 132 | + self.flags, |
34 | 133 | )
|
35 | 134 | };
|
36 |
| - // NTSTATUS codes use the two highest bits for severity status. |
37 |
| - if ret >> 30 == 0b11 { |
38 |
| - // We zeroize the highest bit, so the error code will reside |
39 |
| - // inside the range designated for OS codes. |
40 |
| - let code = ret ^ (1 << 31); |
41 |
| - // SAFETY: the second highest bit is always equal to one, |
42 |
| - // so it's impossible to get zero. Unfortunately the type |
43 |
| - // system does not have a way to express this yet. |
44 |
| - let code = unsafe { NonZeroU32::new_unchecked(code) }; |
45 |
| - return Err(Error::from(code)); |
| 135 | + |
| 136 | + if nt_success(ret) { |
| 137 | + return Ok(()); |
46 | 138 | }
|
| 139 | + |
| 140 | + Err(nt_error(ret)) |
47 | 141 | }
|
48 |
| - Ok(()) |
| 142 | +} |
| 143 | + |
| 144 | +/// Generate random numbers using the fallback RNG function |
| 145 | +#[inline(never)] |
| 146 | +fn fallback_rng(dest: &mut [MaybeUninit<u8>]) -> Result<(), Error> { |
| 147 | + Rng::open()?.random(dest) |
49 | 148 | }
|
0 commit comments