diff --git a/library/core/src/char/decode.rs b/library/core/src/char/decode.rs index 5dd8c5ef78941..8b9f979b573f7 100644 --- a/library/core/src/char/decode.rs +++ b/library/core/src/char/decode.rs @@ -120,9 +120,34 @@ impl> Iterator for DecodeUtf16 { #[inline] fn size_hint(&self) -> (usize, Option) { let (low, high) = self.iter.size_hint(); - // we could be entirely valid surrogates (2 elements per - // char), or entirely non-surrogates (1 element per char) - (low / 2, high) + + let (low_buf, high_buf) = match self.buf { + // buf is empty, no additional elements from it. + None => (0, 0), + // `u` is a non surrogate, so it's always an additional character. + Some(u) if u < 0xD800 || 0xDFFF < u => (1, 1), + // `u` is a leading surrogate (it can never be a trailing surrogate and + // it's a surrogate due to the previous branch) and `self.iter` is empty. + // + // `u` can't be paired, since the `self.iter` is empty, + // so it will always become an additional element (error). + Some(_u) if high == Some(0) => (1, 1), + // `u` is a leading surrogate and `iter` may be non-empty. + // + // `u` can either pair with a trailing surrogate, in which case no additional elements + // are produced, or it can become an error, in which case it's an additional character (error). + Some(_u) => (0, 1), + }; + + // `self.iter` could contain entirely valid surrogates (2 elements per + // char), or entirely non-surrogates (1 element per char). + // + // On odd lower bound, at least one element must stay unpaired + // (with other elements from `self.iter`), so we round up. + let low = low.div_ceil(2) + low_buf; + let high = high.and_then(|h| h.checked_add(high_buf)); + + (low, high) } } diff --git a/library/core/tests/char.rs b/library/core/tests/char.rs index 2b857a6591929..4c899b6eb43d0 100644 --- a/library/core/tests/char.rs +++ b/library/core/tests/char.rs @@ -308,6 +308,33 @@ fn test_decode_utf16() { check(&[0xD800, 0], &[Err(0xD800), Ok('\0')]); } +#[test] +fn test_decode_utf16_size_hint() { + fn check(s: &[u16]) { + let mut iter = char::decode_utf16(s.iter().cloned()); + + loop { + let count = iter.clone().count(); + let (lower, upper) = iter.size_hint(); + + assert!( + lower <= count && count <= upper.unwrap(), + "lower = {lower}, count = {count}, upper = {upper:?}" + ); + + if let None = iter.next() { + break; + } + } + } + + check(&[0xD800, 0xD800, 0xDC00]); + check(&[0xD800, 0xD800, 0x0]); + check(&[0xD800, 0x41, 0x42]); + check(&[0xD800, 0]); + check(&[0xD834, 0x006d]); +} + #[test] fn ed_iterator_specializations() { // Check counting