Skip to content

Commit c9764cd

Browse files
author
Oliver Scherer
committed
Implement derives for generic wrapper types
1 parent ed849ab commit c9764cd

File tree

2 files changed

+133
-15
lines changed

2 files changed

+133
-15
lines changed

src/lib.rs

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,13 @@ impl NumTraits {
232232
pub fn from_primitive(input: TokenStream) -> TokenStream {
233233
let ast: syn::DeriveInput = syn::parse(input).unwrap();
234234
let name = &ast.ident;
235+
let (impl_, type_, where_) = split_for_impl(&ast.generics);
235236

236237
let import = NumTraits::new(&ast);
237238

238239
let impl_ = if let Some(inner_ty) = newtype_inner(&ast.data) {
239240
quote! {
240-
impl #import::FromPrimitive for #name {
241+
impl #impl_ #import::FromPrimitive for #name #type_ #where_ #inner_ty: #import::FromPrimitive {
241242
fn from_i64(n: i64) -> Option<Self> {
242243
<#inner_ty as #import::FromPrimitive>::from_i64(n).map(#name)
243244
}
@@ -320,7 +321,7 @@ pub fn from_primitive(input: TokenStream) -> TokenStream {
320321
};
321322

322323
quote! {
323-
impl #import::FromPrimitive for #name {
324+
impl #impl_ #import::FromPrimitive for #name #type_ #where_ {
324325
#[allow(trivial_numeric_casts)]
325326
fn from_i64(#from_i64_var: i64) -> Option<Self> {
326327
#(#clauses else)* {
@@ -390,12 +391,13 @@ pub fn from_primitive(input: TokenStream) -> TokenStream {
390391
pub fn to_primitive(input: TokenStream) -> TokenStream {
391392
let ast: syn::DeriveInput = syn::parse(input).unwrap();
392393
let name = &ast.ident;
394+
let (impl_, type_, where_) = split_for_impl(&ast.generics);
393395

394396
let import = NumTraits::new(&ast);
395397

396398
let impl_ = if let Some(inner_ty) = newtype_inner(&ast.data) {
397399
quote! {
398-
impl #import::ToPrimitive for #name {
400+
impl #impl_ #import::ToPrimitive for #name #type_ #where_ #inner_ty: #import::ToPrimitive {
399401
fn to_i64(&self) -> Option<i64> {
400402
<#inner_ty as #import::ToPrimitive>::to_i64(&self.0)
401403
}
@@ -481,7 +483,7 @@ pub fn to_primitive(input: TokenStream) -> TokenStream {
481483
};
482484

483485
quote! {
484-
impl #import::ToPrimitive for #name {
486+
impl #impl_ #import::ToPrimitive for #name #type_ #where_ {
485487
#[allow(trivial_numeric_casts)]
486488
fn to_i64(&self) -> Option<i64> {
487489
#match_expr
@@ -511,33 +513,34 @@ const NEWTYPE_ONLY: &str = "This trait can only be derived for newtypes";
511513
pub fn num_ops(input: TokenStream) -> TokenStream {
512514
let ast: syn::DeriveInput = syn::parse(input).unwrap();
513515
let name = &ast.ident;
516+
let (impl_, type_, where_) = split_for_impl(&ast.generics);
514517
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
515518
let impl_ = quote! {
516-
impl ::std::ops::Add for #name {
519+
impl #impl_ ::std::ops::Add for #name #type_ #where_ #inner_ty: ::std::ops::Add<Output = #inner_ty> {
517520
type Output = Self;
518521
fn add(self, other: Self) -> Self {
519522
#name(<#inner_ty as ::std::ops::Add>::add(self.0, other.0))
520523
}
521524
}
522-
impl ::std::ops::Sub for #name {
525+
impl #impl_ ::std::ops::Sub for #name #type_ #where_ #inner_ty: ::std::ops::Sub<Output = #inner_ty> {
523526
type Output = Self;
524527
fn sub(self, other: Self) -> Self {
525528
#name(<#inner_ty as ::std::ops::Sub>::sub(self.0, other.0))
526529
}
527530
}
528-
impl ::std::ops::Mul for #name {
531+
impl #impl_ ::std::ops::Mul for #name #type_ #where_ #inner_ty: ::std::ops::Mul<Output = #inner_ty> {
529532
type Output = Self;
530533
fn mul(self, other: Self) -> Self {
531534
#name(<#inner_ty as ::std::ops::Mul>::mul(self.0, other.0))
532535
}
533536
}
534-
impl ::std::ops::Div for #name {
537+
impl #impl_ ::std::ops::Div for #name #type_ #where_ #inner_ty: ::std::ops::Div<Output = #inner_ty> {
535538
type Output = Self;
536539
fn div(self, other: Self) -> Self {
537540
#name(<#inner_ty as ::std::ops::Div>::div(self.0, other.0))
538541
}
539542
}
540-
impl ::std::ops::Rem for #name {
543+
impl #impl_ ::std::ops::Rem for #name #type_ #where_ #inner_ty: ::std::ops::Rem<Output = #inner_ty> {
541544
type Output = Self;
542545
fn rem(self, other: Self) -> Self {
543546
#name(<#inner_ty as ::std::ops::Rem>::rem(self.0, other.0))
@@ -555,13 +558,16 @@ pub fn num_ops(input: TokenStream) -> TokenStream {
555558
pub fn num_cast(input: TokenStream) -> TokenStream {
556559
let ast: syn::DeriveInput = syn::parse(input).unwrap();
557560
let name = &ast.ident;
561+
let (impl_, type_, where_) = split_for_impl(&ast.generics);
558562
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
563+
let fn_param = proc_macro2::Ident::new("FROM_T", name.span());
559564

560565
let import = NumTraits::new(&ast);
561566

562567
let impl_ = quote! {
563-
impl #import::NumCast for #name {
564-
fn from<T: #import::ToPrimitive>(n: T) -> Option<Self> {
568+
impl #impl_ #import::NumCast for #name #type_ #where_ #inner_ty: #import::NumCast {
569+
#[allow(non_camel_case_types)]
570+
fn from<#fn_param: #import::ToPrimitive>(n: #fn_param) -> Option<Self> {
565571
<#inner_ty as #import::NumCast>::from(n).map(#name)
566572
}
567573
}
@@ -577,12 +583,13 @@ pub fn num_cast(input: TokenStream) -> TokenStream {
577583
pub fn zero(input: TokenStream) -> TokenStream {
578584
let ast: syn::DeriveInput = syn::parse(input).unwrap();
579585
let name = &ast.ident;
586+
let (impl_, type_, where_) = split_for_impl(&ast.generics);
580587
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
581588

582589
let import = NumTraits::new(&ast);
583590

584591
let impl_ = quote! {
585-
impl #import::Zero for #name {
592+
impl #impl_ #import::Zero for #name #type_ #where_ #inner_ty: #import::Zero {
586593
fn zero() -> Self {
587594
#name(<#inner_ty as #import::Zero>::zero())
588595
}
@@ -602,12 +609,13 @@ pub fn zero(input: TokenStream) -> TokenStream {
602609
pub fn one(input: TokenStream) -> TokenStream {
603610
let ast: syn::DeriveInput = syn::parse(input).unwrap();
604611
let name = &ast.ident;
612+
let (impl_, type_, where_) = split_for_impl(&ast.generics);
605613
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
606614

607615
let import = NumTraits::new(&ast);
608616

609617
let impl_ = quote! {
610-
impl #import::One for #name {
618+
impl #impl_ #import::One for #name #type_ #where_ #inner_ty: #import::One + PartialEq {
611619
fn one() -> Self {
612620
#name(<#inner_ty as #import::One>::one())
613621
}
@@ -620,19 +628,31 @@ pub fn one(input: TokenStream) -> TokenStream {
620628
import.wrap("One", &name, impl_).into()
621629
}
622630

631+
fn split_for_impl(
632+
generics: &syn::Generics,
633+
) -> (syn::ImplGenerics, syn::TypeGenerics, impl quote::ToTokens) {
634+
let (impl_, type_, where_) = generics.split_for_impl();
635+
let where_ = match where_ {
636+
Some(where_) => quote! { #where_, },
637+
None => quote! { where },
638+
};
639+
(impl_, type_, where_)
640+
}
641+
623642
/// Derives [`num_traits::Num`][num] for newtypes. The inner type must already implement `Num`.
624643
///
625644
/// [num]: https://docs.rs/num-traits/0.2/num_traits/trait.Num.html
626645
#[proc_macro_derive(Num, attributes(num_traits))]
627646
pub fn num(input: TokenStream) -> TokenStream {
628647
let ast: syn::DeriveInput = syn::parse(input).unwrap();
629648
let name = &ast.ident;
649+
let (impl_, type_, where_) = split_for_impl(&ast.generics);
630650
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
631651

632652
let import = NumTraits::new(&ast);
633653

634654
let impl_ = quote! {
635-
impl #import::Num for #name {
655+
impl #impl_ #import::Num for #name #type_ #where_ #inner_ty: #import::Num {
636656
type FromStrRadixErr = <#inner_ty as #import::Num>::FromStrRadixErr;
637657
fn from_str_radix(s: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
638658
<#inner_ty as #import::Num>::from_str_radix(s, radix).map(#name)
@@ -651,12 +671,13 @@ pub fn num(input: TokenStream) -> TokenStream {
651671
pub fn float(input: TokenStream) -> TokenStream {
652672
let ast: syn::DeriveInput = syn::parse(input).unwrap();
653673
let name = &ast.ident;
674+
let (impl_, type_, where_) = split_for_impl(&ast.generics);
654675
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
655676

656677
let import = NumTraits::new(&ast);
657678

658679
let impl_ = quote! {
659-
impl #import::Float for #name {
680+
impl #impl_ #import::Float for #name #type_ #where_ #inner_ty: #import::Float {
660681
fn nan() -> Self {
661682
#name(<#inner_ty as #import::Float>::nan())
662683
}

tests/generic_newtype.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
extern crate num as num_renamed;
2+
#[macro_use]
3+
extern crate num_derive;
4+
5+
use crate::num_renamed::{Float, FromPrimitive, Num, NumCast, One, ToPrimitive, Zero};
6+
use std::ops::Neg;
7+
8+
#[derive(
9+
Debug,
10+
Clone,
11+
Copy,
12+
PartialEq,
13+
PartialOrd,
14+
ToPrimitive,
15+
FromPrimitive,
16+
NumOps,
17+
NumCast,
18+
One,
19+
Zero,
20+
Num,
21+
Float,
22+
)]
23+
struct MyThing<T: Cake>(T)
24+
where
25+
T: Lie;
26+
27+
trait Cake {}
28+
trait Lie {}
29+
30+
impl Cake for f32 {}
31+
impl Lie for f32 {}
32+
33+
impl<T: Neg<Output = T> + Cake + Lie> Neg for MyThing<T> {
34+
type Output = Self;
35+
fn neg(self) -> Self {
36+
MyThing(self.0.neg())
37+
}
38+
}
39+
40+
#[test]
41+
fn test_from_primitive() {
42+
assert_eq!(MyThing::from_u32(25), Some(MyThing(25.0)));
43+
}
44+
45+
#[test]
46+
fn test_from_primitive_128() {
47+
assert_eq!(
48+
MyThing::from_i128(std::i128::MIN),
49+
Some(MyThing((-2.0).powi(127)))
50+
);
51+
}
52+
53+
#[test]
54+
fn test_to_primitive() {
55+
assert_eq!(MyThing(25.0).to_u32(), Some(25));
56+
}
57+
58+
#[test]
59+
fn test_to_primitive_128() {
60+
let f: MyThing<f32> = MyThing::from_f32(std::f32::MAX).unwrap();
61+
assert_eq!(f.to_i128(), None);
62+
assert_eq!(f.to_u128(), Some(0xffff_ff00_0000_0000_0000_0000_0000_0000));
63+
}
64+
65+
#[test]
66+
fn test_num_ops() {
67+
assert_eq!(MyThing(25.0) + MyThing(10.0), MyThing(35.0));
68+
assert_eq!(MyThing(25.0) - MyThing(10.0), MyThing(15.0));
69+
assert_eq!(MyThing(25.0) * MyThing(2.0), MyThing(50.0));
70+
assert_eq!(MyThing(25.0) / MyThing(10.0), MyThing(2.5));
71+
assert_eq!(MyThing(25.0) % MyThing(10.0), MyThing(5.0));
72+
}
73+
74+
#[test]
75+
fn test_num_cast() {
76+
assert_eq!(<MyThing<f32> as NumCast>::from(25u8), Some(MyThing(25.0)));
77+
}
78+
79+
#[test]
80+
fn test_zero() {
81+
assert_eq!(MyThing::zero(), MyThing(0.0));
82+
}
83+
84+
#[test]
85+
fn test_one() {
86+
assert_eq!(MyThing::one(), MyThing(1.0));
87+
}
88+
89+
#[test]
90+
fn test_num() {
91+
assert_eq!(MyThing::from_str_radix("25", 10).ok(), Some(MyThing(25.0)));
92+
}
93+
94+
#[test]
95+
fn test_float() {
96+
assert_eq!(MyThing(4.0).log(MyThing(2.0)), MyThing(2.0));
97+
}

0 commit comments

Comments
 (0)