diff --git a/msgpacker-derive/Cargo.toml b/msgpacker-derive/Cargo.toml index 7ccc16a..0c205f2 100644 --- a/msgpacker-derive/Cargo.toml +++ b/msgpacker-derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "msgpacker-derive" -version = "0.3.1" +version = "0.4.0" authors = ["Victor Lopez "] categories = ["compression", "encoding", "parser-implementations"] edition = "2021" diff --git a/msgpacker-derive/src/lib.rs b/msgpacker-derive/src/lib.rs index 401cdd0..5977a77 100644 --- a/msgpacker-derive/src/lib.rs +++ b/msgpacker-derive/src/lib.rs @@ -36,20 +36,62 @@ fn contains_attribute(field: &Field, name: &str) -> bool { fn impl_fields_named(name: Ident, f: FieldsNamed) -> impl Into { let mut values: Punctuated = Punctuated::new(); + let field_len = f.named.len(); let block_packable: Block = parse_quote! { { let mut n = 0; + n += ::msgpacker::derive_util::get_array_info(buf, #field_len); } }; let block_unpackable: Block = parse_quote! { { let mut n = 0; + let expected_len = #field_len; + + let format = ::msgpacker::derive_util::take_byte(&mut buf)?; + + let (header_bytes, actual_len) = match format { + 0x90..=0x9f => (1, (format & 0x0f) as usize), + ::msgpacker::derive_util::Format::ARRAY16 => { + let len = ::msgpacker::derive_util::take_num(&mut buf, u16::from_be_bytes)? as usize; + (3, len) + } + ::msgpacker::derive_util::Format::ARRAY32 => { + let len = ::msgpacker::derive_util::take_num(&mut buf, u32::from_be_bytes)? as usize; + (5, len) + } + _ => return Err(::msgpacker::Error::UnexpectedFormatTag.into()), + }; + + if actual_len != expected_len { + return Err(::msgpacker::Error::UnexpectedStructLength.into()); + } + + n += header_bytes; } }; let block_unpackable_iter: Block = parse_quote! { { let mut bytes = bytes.into_iter(); let mut n = 0; + let expected_len = #field_len; + let format = ::msgpacker::derive_util::take_byte_iter(bytes.by_ref())?; + let (header_bytes, actual_len) = match format { + 0x90..=0x9f => (1, (format & 0x0f) as usize), + ::msgpacker::derive_util::Format::ARRAY16 => { + let len = ::msgpacker::derive_util::take_num_iter(&mut bytes, u16::from_be_bytes)? as usize; + (3, len) + } + ::msgpacker::derive_util::Format::ARRAY32 => { + let len = ::msgpacker::derive_util::take_num_iter(&mut bytes, u16::from_be_bytes)? as usize; + (5, len) + } + _ => return Err(::msgpacker::Error::UnexpectedFormatTag.into()), + }; + if actual_len != expected_len { + return Err(::msgpacker::Error::UnexpectedStructLength.into()); + } + n += header_bytes; } }; @@ -103,7 +145,26 @@ fn impl_fields_named(name: Ident, f: FieldsNamed) -> impl Into { t })?; }); - } else if contains_attribute(&field, "array") || is_vec && !is_vec_u8 { + } else if contains_attribute(&field, "binary") && is_vec_u8 { + block_packable.stmts.push(parse_quote! { + n += ::msgpacker::pack_binary(buf, &self.#ident); + }); + + block_unpackable.stmts.push(parse_quote! { + let #ident = ::msgpacker::unpack_bytes(buf).map(|(nv, t)| { + n += nv; + buf = &buf[nv..]; + t.to_vec() + })?; + }); + + block_unpackable_iter.stmts.push(parse_quote! { + let #ident = ::msgpacker::unpack_bytes_iter(bytes.by_ref()).map(|(nv, t)| { + n += nv; + t + })?; + }); + } else if contains_attribute(&field, "array") || is_vec { block_packable.stmts.push(parse_quote! { n += ::msgpacker::pack_array(buf, &self.#ident); }); @@ -298,11 +359,11 @@ fn impl_fields_unnamed(name: Ident, f: FieldsUnnamed) -> impl Into fn impl_fields_unit(name: Ident) -> impl Into { quote! { impl ::msgpacker::Packable for #name { - fn pack(&self, _buf: &mut T) -> usize + fn pack(&self, buf: &mut T) -> usize where T: Extend, { - 0 + ::msgpacker::derive_util::get_array_info(buf, 0) } } @@ -310,14 +371,25 @@ fn impl_fields_unit(name: Ident) -> impl Into { type Error = ::msgpacker::Error; fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> { - Ok((0, Self)) + let format = ::msgpacker::derive_util::take_byte(&mut buf)?; + let (_, len) = match format { + 0x90 => (1, 0), + _ => return Err(Error::UnexpectedFormatTag.into()), + }; + Ok((1, Self)) } fn unpack_iter(bytes: I) -> Result<(usize, Self), Self::Error> where I: IntoIterator, { - Ok((0, Self)) + let mut bytes = bytes.into_iter(); + let format = ::msgpacker::derive_util::take_byte_iter(bytes.by_ref())?; + let (_, len) = match format { + 0x90 => (1, 0), + _ => return Err(Error::UnexpectedFormatTag.into()), + }; + Ok((1, Self)) } } } diff --git a/msgpacker/Cargo.toml b/msgpacker/Cargo.toml index aaa8549..7f10527 100644 --- a/msgpacker/Cargo.toml +++ b/msgpacker/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "msgpacker" -version = "0.4.7" +version = "0.5.0" authors = ["Victor Lopez "] categories = ["compression", "encoding", "parser-implementations"] edition = "2021" @@ -11,7 +11,7 @@ repository = "https://github.com/codx-dev/msgpacker" description = "MessagePack protocol implementation for Rust." [dependencies] -msgpacker-derive = { version = "0.3", optional = true } +msgpacker-derive = { path = "../msgpacker-derive", optional = true } [dev-dependencies] proptest = "1.2" diff --git a/msgpacker/src/error.rs b/msgpacker/src/error.rs index 16e3ed4..ddeaffb 100644 --- a/msgpacker/src/error.rs +++ b/msgpacker/src/error.rs @@ -15,6 +15,10 @@ pub enum Error { UnexpectedFormatTag, /// The provided bin length is not valid. UnexpectedBinLength, + /// The struct we're targeting does not match the data. + UnexpectedStructLength, + /// The array we're targeting does not match the data. + UnexpectedArrayLength, } impl fmt::Display for Error { diff --git a/msgpacker/src/format.rs b/msgpacker/src/format.rs index 7345495..e1d22fb 100644 --- a/msgpacker/src/format.rs +++ b/msgpacker/src/format.rs @@ -1,40 +1,73 @@ +/// A container for the Format constants pub struct Format {} impl Format { + /// Nil format stores nil in 1 byte. pub const NIL: u8 = 0xc0; + /// Bool format family stores false or true in 1 byte. pub const TRUE: u8 = 0xc3; + /// Bool format family stores false or true in 1 byte. pub const FALSE: u8 = 0xc2; + /// Positive fixint stores 7-bit positive integer pub const POSITIVE_FIXINT: u8 = 0x7f; + /// Uint 8 stores a 8-bit unsigned integer pub const UINT8: u8 = 0xcc; + /// Uint 16 stores a 16-bit big-endian unsigned integer pub const UINT16: u8 = 0xcd; + /// Uint 32 stores a 32-bit big-endian unsigned integer pub const UINT32: u8 = 0xce; + /// Uint 64 stores a 64-bit big-endian unsigned integer pub const UINT64: u8 = 0xcf; + /// Int 8 stores a 8-bit signed integer pub const INT8: u8 = 0xd0; + /// Int 16 stores a 16-bit big-endian signed integer pub const INT16: u8 = 0xd1; + /// Int 32 stores a 32-bit big-endian signed integer pub const INT32: u8 = 0xd2; + /// Int 64 stores a 64-bit big-endian signed integer pub const INT64: u8 = 0xd3; + /// Float 32 stores a floating point number in IEEE 754 single precision floating point number pub const FLOAT32: u8 = 0xca; + /// Float 64 stores a floating point number in IEEE 754 double precision floating point number pub const FLOAT64: u8 = 0xcb; + /// Bin 8 stores a byte array whose length is upto (2^8)-1 bytes pub const BIN8: u8 = 0xc4; + /// Bin 16 stores a byte array whose length is upto (2^16)-1 bytes pub const BIN16: u8 = 0xc5; + /// Bin 32 stores a byte array whose length is upto (2^32)-1 bytes pub const BIN32: u8 = 0xc6; + /// Str 8 stores a byte array whose length is upto (2^8)-1 bytes pub const STR8: u8 = 0xd9; + /// Str 16 stores a byte array whose length is upto (2^16)-1 bytes pub const STR16: u8 = 0xda; + /// Str 32 stores a byte array whose length is upto (2^32)-1 bytes pub const STR32: u8 = 0xdb; + /// Array 16 stores an array whose length is upto (2^16)-1 elements pub const ARRAY16: u8 = 0xdc; + /// Array 32 stores an array whose length is upto (2^32)-1 elements pub const ARRAY32: u8 = 0xdd; + /// Map 16 stores a map whose length is upto (2^16)-1 elements pub const MAP16: u8 = 0xde; + /// Map 32 stores a map whose length is upto (2^32)-1 elements pub const MAP32: u8 = 0xdf; } #[cfg(feature = "alloc")] impl Format { + /// Fixext 1 stores an integer and a byte array whose length is 1 byte pub const FIXEXT1: u8 = 0xd4; + /// Fixext 2 stores an integer and a byte array whose length is 2 byte pub const FIXEXT2: u8 = 0xd5; + /// Fixext 4 stores an integer and a byte array whose length is 4 byte pub const FIXEXT4: u8 = 0xd6; + /// Fixext 8 stores an integer and a byte array whose length is 8 byte pub const FIXEXT8: u8 = 0xd7; + /// Fixext 16 stores an integer and a byte array whose length is 16 byte pub const FIXEXT16: u8 = 0xd8; + /// Ext 8 stores an integer and a byte array whose length is upto (2^8)-1 bytes pub const EXT8: u8 = 0xc7; + /// Ext 16 stores an integer and a byte array whose length is upto (2^16)-1 bytes pub const EXT16: u8 = 0xc8; + /// Ext 32 stores an integer and a byte array whose length is upto (2^32)-1 bytes pub const EXT32: u8 = 0xc9; } diff --git a/msgpacker/src/helpers.rs b/msgpacker/src/helpers.rs index 8efb98e..f30baba 100644 --- a/msgpacker/src/helpers.rs +++ b/msgpacker/src/helpers.rs @@ -1,12 +1,6 @@ use super::Error; -pub fn take_byte_iter(mut bytes: I) -> Result -where - I: Iterator, -{ - bytes.next().ok_or(Error::BufferTooShort) -} - +/// Take one byte off the provided buffer, advance the pointer, or error. pub fn take_byte(buf: &mut &[u8]) -> Result { if buf.is_empty() { return Err(Error::BufferTooShort); @@ -16,6 +10,15 @@ pub fn take_byte(buf: &mut &[u8]) -> Result { Ok(l[0]) } +/// Take one byte from the iterator or error. +pub fn take_byte_iter(mut bytes: I) -> Result +where + I: Iterator, +{ + bytes.next().ok_or(Error::BufferTooShort) +} + +/// Read bytes off the buffer, using the provided function, or error. pub fn take_num(buf: &mut &[u8], f: fn([u8; N]) -> V) -> Result { if buf.len() < N { return Err(Error::BufferTooShort); @@ -27,6 +30,18 @@ pub fn take_num(buf: &mut &[u8], f: fn([u8; N]) -> V) -> Resu Ok(f(val)) } +/// Read a number off the iterator, using the provided function, or error. +pub fn take_num_iter(mut bytes: I, f: fn([u8; N]) -> V) -> Result +where + I: Iterator, +{ + let mut array = [0u8; N]; // Initialize with zeroes + for byte in array.iter_mut() { + *byte = bytes.next().ok_or(Error::BufferTooShort)?; + } + Ok(f(array)) +} + #[cfg(feature = "alloc")] pub fn take_buffer<'a>(buf: &mut &'a [u8], len: usize) -> Result<&'a [u8], Error> { if buf.len() < len { @@ -37,29 +52,6 @@ pub fn take_buffer<'a>(buf: &mut &'a [u8], len: usize) -> Result<&'a [u8], Error Ok(l) } -pub fn take_num_iter(bytes: I, f: fn([u8; N]) -> V) -> Result -where - I: Iterator, -{ - let mut array = [0u8; N]; - let mut i = 0; - - for b in bytes { - array[i] = b; - i += 1; - - if i == N { - break; - } - } - - if i < N { - return Err(Error::BufferTooShort); - } - - Ok(f(array)) -} - #[cfg(feature = "alloc")] pub fn take_buffer_iter(bytes: I, len: usize) -> Result, Error> where diff --git a/msgpacker/src/lib.rs b/msgpacker/src/lib.rs index 845444d..6e077d6 100644 --- a/msgpacker/src/lib.rs +++ b/msgpacker/src/lib.rs @@ -16,8 +16,18 @@ mod unpack; pub use error::Error; use format::Format; -pub use pack::{pack_array, pack_map}; -pub use unpack::{unpack_array, unpack_array_iter, unpack_map, unpack_map_iter}; +pub use pack::{pack_array, pack_binary, pack_map}; +pub use unpack::{ + unpack_array, unpack_array_iter, unpack_binary, unpack_binary_iter, unpack_map, unpack_map_iter, +}; + +/// This module exposes some utility variables and functions for msgpacker-derive +#[cfg(feature = "derive")] +pub mod derive_util { + pub use crate::format::Format; + pub use crate::helpers::{take_byte, take_byte_iter, take_num, take_num_iter}; + pub use crate::pack::get_array_info; +} #[cfg(feature = "alloc")] pub use extension::Extension; diff --git a/msgpacker/src/pack/binary.rs b/msgpacker/src/pack/binary.rs index 01f1d2c..c73d798 100644 --- a/msgpacker/src/pack/binary.rs +++ b/msgpacker/src/pack/binary.rs @@ -1,29 +1,30 @@ use super::{Format, Packable}; use core::iter; -impl Packable for [u8] { - #[allow(unreachable_code)] - fn pack(&self, buf: &mut T) -> usize - where - T: Extend, - { - let n = if self.len() <= u8::MAX as usize { - buf.extend(iter::once(Format::BIN8).chain(iter::once(self.len() as u8))); - 2 - } else if self.len() <= u16::MAX as usize { - buf.extend(iter::once(Format::BIN16).chain((self.len() as u16).to_be_bytes())); - 3 - } else if self.len() <= u32::MAX as usize { - buf.extend(iter::once(Format::BIN32).chain((self.len() as u32).to_be_bytes())); - 5 - } else { - #[cfg(feature = "strict")] - panic!("strict serialization enabled; the buffer is too large"); - return 0; - }; - buf.extend(self.iter().copied()); - n + self.len() - } +/// Packs a u8 array as binary data into the extendable buffer, returning the amount of written bytes. +#[allow(unreachable_code)] +pub fn pack_binary(buf: &mut T, data: &[u8]) -> usize +where + T: Extend, +{ + let len = data.len(); + + let n = if len <= u8::MAX as usize { + buf.extend(iter::once(Format::BIN8).chain(iter::once(len as u8))); + 2 + } else if len <= u16::MAX as usize { + buf.extend(iter::once(Format::BIN16).chain((len as u16).to_be_bytes())); + 3 + } else if len <= u32::MAX as usize { + buf.extend(iter::once(Format::BIN32).chain((len as u32).to_be_bytes())); + 5 + } else { + #[cfg(feature = "strict")] + panic!("strict serialization enabled; the buffer is too large"); + return 0; + }; + buf.extend(data.iter().copied()); + n + len } #[allow(unreachable_code)] @@ -57,23 +58,23 @@ impl Packable for str { #[cfg(feature = "alloc")] mod alloc { use super::*; - use ::alloc::{string::String, vec::Vec}; + use ::alloc::string::String; - impl Packable for Vec { + impl Packable for String { fn pack(&self, buf: &mut T) -> usize where T: Extend, { - self.as_slice().pack(buf) + self.as_str().pack(buf) } } - impl Packable for String { + impl Packable for Box<[u8]> { fn pack(&self, buf: &mut T) -> usize where T: Extend, { - self.as_str().pack(buf) + pack_binary(buf, self) } } } diff --git a/msgpacker/src/pack/collections.rs b/msgpacker/src/pack/collections.rs index bafee5f..b181665 100644 --- a/msgpacker/src/pack/collections.rs +++ b/msgpacker/src/pack/collections.rs @@ -1,18 +1,12 @@ use super::{Format, Packable}; use core::{borrow::Borrow, iter}; -/// Packs an array into the extendable buffer, returning the amount of written bytes. -#[allow(unreachable_code)] -pub fn pack_array(buf: &mut T, iter: A) -> usize +/// Writes the info for an array into the extendable buffer, returning the amount of written bytes. +pub fn get_array_info(buf: &mut T, len: usize) -> usize where T: Extend, - A: IntoIterator, - I: Iterator + ExactSizeIterator, - V: Packable, { - let values = iter.into_iter(); - let len = values.len(); - let n = if len <= 15 { + if len <= 15 { buf.extend(iter::once(((len & 0x0f) as u8) | 0x90)); 1 } else if len <= u16::MAX as usize { @@ -25,7 +19,22 @@ where #[cfg(feature = "strict")] panic!("strict serialization enabled; the buffer is too large"); return 0; - }; + } +} + +/// Packs an array into the extendable buffer, returning the amount of written bytes. +#[allow(unreachable_code)] +pub fn pack_array(buf: &mut T, iter: A) -> usize +where + T: Extend, + A: IntoIterator, + I: Iterator + ExactSizeIterator, + V: Packable, +{ + let values = iter.into_iter(); + let len = values.len(); + + let n = get_array_info(buf, len); n + values.map(|v| v.pack(buf)).sum::() } @@ -129,6 +138,18 @@ mod alloc { pack_map(buf, self) } } + + impl Packable for Vec + where + X: Packable, + { + fn pack(&self, buf: &mut T) -> usize + where + T: Extend, + { + pack_array(buf, self) + } + } } #[cfg(feature = "std")] diff --git a/msgpacker/src/pack/common.rs b/msgpacker/src/pack/common.rs index 9c946d3..f419518 100644 --- a/msgpacker/src/pack/common.rs +++ b/msgpacker/src/pack/common.rs @@ -1,4 +1,4 @@ -use super::{Format, Packable}; +use super::{get_array_info, Format, Packable}; use core::{iter, marker::PhantomData}; impl Packable for () { @@ -61,7 +61,9 @@ macro_rules! array { where T: Extend, { - self.iter().map(|t| t.pack(buf)).sum() + let len = self.len(); + let n = get_array_info(buf, len); + n + self.iter().map(|t| t.pack(buf)).sum::() } } }; diff --git a/msgpacker/src/pack/mod.rs b/msgpacker/src/pack/mod.rs index d31ce85..6d7f04f 100644 --- a/msgpacker/src/pack/mod.rs +++ b/msgpacker/src/pack/mod.rs @@ -6,4 +6,5 @@ mod common; mod float; mod int; -pub use collections::{pack_array, pack_map}; +pub use binary::pack_binary; +pub use collections::{get_array_info, pack_array, pack_map}; diff --git a/msgpacker/src/unpack/binary.rs b/msgpacker/src/unpack/binary.rs index b49bf7a..8056729 100644 --- a/msgpacker/src/unpack/binary.rs +++ b/msgpacker/src/unpack/binary.rs @@ -10,7 +10,8 @@ use super::{ use alloc::{string::String, vec::Vec}; use core::str; -pub fn unpack_bytes(mut buf: &[u8]) -> Result<(usize, &[u8]), Error> { +/// Unpacks binary data from the buffer, returning a &[u8] and the amount of read bytes. +pub fn unpack_binary(mut buf: &[u8]) -> Result<(usize, &[u8]), Error> { let format = take_byte(&mut buf)?; let (n, len) = match format { Format::BIN8 => (2, take_byte(&mut buf)? as usize), @@ -24,6 +25,32 @@ pub fn unpack_bytes(mut buf: &[u8]) -> Result<(usize, &[u8]), Error> { Ok((n + len, &buf[..len])) } +/// Unpacks binary data from the iterator, returning a Vec and the amount of read bytes. +pub fn unpack_binary_iter(bytes: I) -> Result<(usize, Vec), Error> +where + I: IntoIterator, +{ + let mut bytes = bytes.into_iter(); + let format = take_byte_iter(bytes.by_ref())?; + let (n, len) = match format { + Format::BIN8 => (2, take_byte_iter(bytes.by_ref())? as usize), + Format::BIN16 => ( + 3, + take_num_iter(bytes.by_ref(), u16::from_be_bytes)? as usize, + ), + Format::BIN32 => ( + 5, + take_num_iter(bytes.by_ref(), u32::from_be_bytes)? as usize, + ), + _ => return Err(Error::UnexpectedFormatTag), + }; + let v: Vec<_> = bytes.take(len).collect(); + if v.len() < len { + return Err(Error::BufferTooShort); + } + Ok((n + len, v)) +} + pub fn unpack_str(mut buf: &[u8]) -> Result<(usize, &str), Error> { let format = take_byte(&mut buf)?; let (n, len) = match format { @@ -40,11 +67,11 @@ pub fn unpack_str(mut buf: &[u8]) -> Result<(usize, &str), Error> { Ok((n + len, str)) } -impl Unpackable for Vec { +impl Unpackable for String { type Error = Error; fn unpack(buf: &[u8]) -> Result<(usize, Self), Self::Error> { - unpack_bytes(buf).map(|(n, b)| (n, b.to_vec())) + unpack_str(buf).map(|(n, s)| (n, s.into())) } fn unpack_iter(bytes: I) -> Result<(usize, Self), Self::Error> @@ -54,12 +81,13 @@ impl Unpackable for Vec { let mut bytes = bytes.into_iter(); let format = take_byte_iter(bytes.by_ref())?; let (n, len) = match format { - Format::BIN8 => (2, take_byte_iter(bytes.by_ref())? as usize), - Format::BIN16 => ( + 0xa0..=0xbf => (1, format as usize & 0x1f), + Format::STR8 => (2, take_byte_iter(bytes.by_ref())? as usize), + Format::STR16 => ( 3, take_num_iter(bytes.by_ref(), u16::from_be_bytes)? as usize, ), - Format::BIN32 => ( + Format::STR32 => ( 5, take_num_iter(bytes.by_ref(), u32::from_be_bytes)? as usize, ), @@ -69,15 +97,16 @@ impl Unpackable for Vec { if v.len() < len { return Err(Error::BufferTooShort); } - Ok((n + len, v)) + let s = String::from_utf8(v).map_err(|_| Error::InvalidUtf8)?; + Ok((n + len, s)) } } -impl Unpackable for String { +impl Unpackable for Box<[u8]> { type Error = Error; fn unpack(buf: &[u8]) -> Result<(usize, Self), Self::Error> { - unpack_str(buf).map(|(n, s)| (n, s.into())) + unpack_binary(buf).map(|(n, b)| (n, b.to_vec().into_boxed_slice())) } fn unpack_iter(bytes: I) -> Result<(usize, Self), Self::Error> @@ -87,13 +116,12 @@ impl Unpackable for String { let mut bytes = bytes.into_iter(); let format = take_byte_iter(bytes.by_ref())?; let (n, len) = match format { - 0xa0..=0xbf => (1, format as usize & 0x1f), - Format::STR8 => (2, take_byte_iter(bytes.by_ref())? as usize), - Format::STR16 => ( + Format::BIN8 => (2, take_byte_iter(bytes.by_ref())? as usize), + Format::BIN16 => ( 3, take_num_iter(bytes.by_ref(), u16::from_be_bytes)? as usize, ), - Format::STR32 => ( + Format::BIN32 => ( 5, take_num_iter(bytes.by_ref(), u32::from_be_bytes)? as usize, ), @@ -103,7 +131,6 @@ impl Unpackable for String { if v.len() < len { return Err(Error::BufferTooShort); } - let s = String::from_utf8(v).map_err(|_| Error::InvalidUtf8)?; - Ok((n + len, s)) + Ok((n + len, v.into_boxed_slice())) } } diff --git a/msgpacker/src/unpack/collections.rs b/msgpacker/src/unpack/collections.rs index 4bf9ff5..f0d15af 100644 --- a/msgpacker/src/unpack/collections.rs +++ b/msgpacker/src/unpack/collections.rs @@ -22,6 +22,7 @@ where ), _ => return Err(Error::UnexpectedFormatTag.into()), }; + let array: C = (0..len) .map(|_| { let (count, v) = V::unpack(buf)?; @@ -230,6 +231,24 @@ mod alloc { unpack_map_iter(bytes) } } + + impl Unpackable for Vec + where + X: Unpackable, + { + type Error = ::Error; + + fn unpack(buf: &[u8]) -> Result<(usize, Self), Self::Error> { + unpack_array(buf) + } + + fn unpack_iter(bytes: I) -> Result<(usize, Self), Self::Error> + where + I: IntoIterator, + { + unpack_array_iter(bytes) + } + } } #[cfg(feature = "std")] diff --git a/msgpacker/src/unpack/common.rs b/msgpacker/src/unpack/common.rs index 58821b5..4646c65 100644 --- a/msgpacker/src/unpack/common.rs +++ b/msgpacker/src/unpack/common.rs @@ -1,8 +1,9 @@ use super::{ - helpers::{take_byte, take_byte_iter}, + helpers::{take_byte, take_byte_iter, take_num, take_num_iter}, Error, Format, Unpackable, }; use core::{marker::PhantomData, mem::MaybeUninit}; +use std::ptr; impl Unpackable for () { type Error = Error; @@ -98,21 +99,35 @@ macro_rules! array { fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> { let mut array = [const { MaybeUninit::uninit() }; $n]; - let n = - array - .iter_mut() - .try_fold::<_, _, Result<_, Self::Error>>(0, |count, a| { - let (n, x) = X::unpack(buf)?; - buf = &buf[n..]; - a.write(x); - Ok(count + n) - })?; + + let format = take_byte(&mut buf)?; + let (mut n, len) = match format { + 0x90..=0x9f => (1, (format & 0x0f) as usize), + Format::ARRAY16 => ( + 3, + take_num(&mut buf, u16::from_be_bytes).map(|v| v as usize)?, + ), + Format::ARRAY32 => ( + 5, + take_num(&mut buf, u32::from_be_bytes).map(|v| v as usize)?, + ), + _ => return Err(Error::UnexpectedFormatTag.into()), + }; + + if len != $n { + return Err(Error::UnexpectedArrayLength.into()); + } + + n += array + .iter_mut() + .try_fold::<_, _, Result<_, Self::Error>>(0, |count, a| { + let (n, x) = X::unpack(buf)?; + buf = &buf[n..]; + a.write(x); + Ok(count + n) + })?; // Safety: array is initialized - let array = ::core::array::from_fn(|i| { - let mut x = MaybeUninit::zeroed(); - ::core::mem::swap(&mut array[i], &mut x); - unsafe { MaybeUninit::assume_init(x) } - }); + let array = unsafe { ptr::read(&array as *const _ as *const [X; $n]) }; Ok((n, array)) } @@ -122,20 +137,34 @@ macro_rules! array { { let mut bytes = bytes.into_iter(); let mut array = [const { MaybeUninit::uninit() }; $n]; - let n = - array - .iter_mut() - .try_fold::<_, _, Result<_, Self::Error>>(0, |count, a| { - let (n, x) = X::unpack_iter(bytes.by_ref())?; - a.write(x); - Ok(count + n) - })?; + + let format = take_byte_iter(bytes.by_ref())?; + let (mut n, len) = match format { + 0x90..=0x9f => (1, (format & 0x0f) as usize), + Format::ARRAY16 => ( + 3, + take_num_iter(bytes.by_ref(), u16::from_be_bytes).map(|v| v as usize)?, + ), + Format::ARRAY32 => ( + 5, + take_num_iter(bytes.by_ref(), u32::from_be_bytes).map(|v| v as usize)?, + ), + _ => return Err(Error::UnexpectedFormatTag.into()), + }; + + if len != $n { + return Err(Error::UnexpectedArrayLength.into()); + } + + n += array + .iter_mut() + .try_fold::<_, _, Result<_, Self::Error>>(0, |count, a| { + let (n, x) = X::unpack_iter(bytes.by_ref())?; + a.write(x); + Ok(count + n) + })?; // Safety: array is initialized - let array = ::core::array::from_fn(|i| { - let mut x = MaybeUninit::zeroed(); - ::core::mem::swap(&mut array[i], &mut x); - unsafe { MaybeUninit::assume_init(x) } - }); + let array = unsafe { ptr::read(&array as *const _ as *const [X; $n]) }; Ok((n, array)) } } diff --git a/msgpacker/src/unpack/mod.rs b/msgpacker/src/unpack/mod.rs index 5f20331..a6b8f4e 100644 --- a/msgpacker/src/unpack/mod.rs +++ b/msgpacker/src/unpack/mod.rs @@ -7,4 +7,5 @@ mod common; mod float; mod int; +pub use binary::{unpack_binary, unpack_binary_iter}; pub use collections::{unpack_array, unpack_array_iter, unpack_map, unpack_map_iter}; diff --git a/msgpacker/tests/binary.rs b/msgpacker/tests/binary.rs index ce79bbe..6e68711 100644 --- a/msgpacker/tests/binary.rs +++ b/msgpacker/tests/binary.rs @@ -5,11 +5,11 @@ mod utils; #[test] fn empty_vec() { - let v = vec![]; + let v: Vec = vec![]; let mut bytes = vec![]; - let n = v.pack(&mut bytes); - let (o, x) = Vec::::unpack(&bytes).unwrap(); - let (p, y) = Vec::::unpack_iter(bytes).unwrap(); + let n = msgpacker::pack_binary(&mut bytes, &v); + let (o, x) = msgpacker::unpack_binary(&bytes).unwrap(); + let (p, y) = msgpacker::unpack_binary_iter(bytes.clone()).unwrap(); assert_eq!(o, n); assert_eq!(p, n); assert_eq!(v, x); @@ -30,6 +30,19 @@ fn empty_str() { } proptest! { + #[test] + fn slice(value: Box<[u8]>) { + let mut bytes = Vec::new(); + let n = msgpacker::pack_binary(&mut bytes, &value); + assert_eq!(n, bytes.len()); + let (o, x): (usize, &[u8]) = msgpacker::unpack_binary(&bytes).unwrap(); + let (p, y): (usize, Vec) = msgpacker::unpack_binary_iter(bytes.clone()).unwrap(); + assert_eq!(n, o); + assert_eq!(n, p); + assert_eq!(value, x.into()); + assert_eq!(value, y.into()); + } + #[test] fn vec(v: Vec) { utils::case(v); diff --git a/msgpacker/tests/collections.rs b/msgpacker/tests/collections.rs index 57b19fb..301e0ee 100644 --- a/msgpacker/tests/collections.rs +++ b/msgpacker/tests/collections.rs @@ -49,6 +49,19 @@ proptest! { assert_eq!(value, y); } + #[test] + fn slice(value: [Value;8]) { + let mut bytes = Vec::new(); + let n = msgpacker::pack_array(&mut bytes, &value); + assert_eq!(n, bytes.len()); + let (o, x): (usize, Vec) = msgpacker::unpack_array(&bytes).unwrap(); + let (p, y): (usize, Vec) = msgpacker::unpack_array_iter(bytes).unwrap(); + assert_eq!(n, o); + assert_eq!(n, p); + assert_eq!(value, x.as_slice()); + assert_eq!(value, y.as_slice()); + } + #[test] fn map(map: HashMap) { let mut bytes = Vec::new();