diff options
Diffstat (limited to 'gen_random_proc_macro')
-rw-r--r-- | gen_random_proc_macro/src/lib.rs | 170 |
1 files changed, 103 insertions, 67 deletions
diff --git a/gen_random_proc_macro/src/lib.rs b/gen_random_proc_macro/src/lib.rs index e621a76..1a620c3 100644 --- a/gen_random_proc_macro/src/lib.rs +++ b/gen_random_proc_macro/src/lib.rs @@ -2,10 +2,14 @@ extern crate quote; extern crate proc_macro2; extern crate syn; -use proc_macro::{TokenStream, TokenTree::{self, Punct, Literal}}; +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::TokenTree as TokenTree2; +use std::str::FromStr; +use std::any::type_name; use quote::quote; -#[proc_macro_derive(GenRandom, attributes(prob))] +#[proc_macro_derive(GenRandom, attributes(prob, scale, bias))] pub fn gen_random_derive(input: TokenStream) -> TokenStream { // Construct a representation of Rust code as a syntax tree // that we can manipulate @@ -14,15 +18,88 @@ pub fn gen_random_derive(input: TokenStream) -> TokenStream { impl_gen_random(&ast) } +fn get_attribute_literal(attrs: &[syn::Attribute], name: &str) -> Option<proc_macro2::Literal> { + let attr = attrs.iter().find(|a| { + let path = &a.path; + if let Some(ident) = path.get_ident() { + ident == name + } else { + false + } + })?; + + let tokens: TokenStream2 = attr.tokens.clone().into(); + let mut tokens: Vec<TokenTree2> = tokens.into_iter().collect(); + if tokens.len() != 2 { + panic!("Expected {name} = <value>"); + } + use TokenTree2::{Punct, Literal}; + match &tokens[0] { + Punct(equals) if equals.as_char() == '=' => {} + _ => panic!("Expected = after {name} attribute"), + }; + + let Literal(literal) = tokens.remove(1) else { + panic!("Bad value for {name} attribute."); + }; + Some(literal) +} + +fn parse_attribute_value<T: FromStr>(attrs: &[syn::Attribute], name: &str) -> Option<T> { + let literal = get_attribute_literal(attrs, name)?; + let Ok(value) = literal.to_string().parse() else { + panic!("Bad {} for {name} attribute", type_name::<T>()) + }; + Some(value) +} + +fn generate_fields(fields: &syn::Fields) -> impl quote::ToTokens { + let mut field_values = quote! {}; + for field in fields.iter() { + if let Some(name) = &field.ident { + field_values.extend(quote! {#name: }); + } + let ty = &field.ty; + field_values.extend(quote! { <#ty as GenRandom>::gen_random(rng) }); + + if let Some(scale) = get_attribute_literal(&field.attrs, "scale") { + field_values.extend(quote! { * #scale }); + } + if let Some(bias) = get_attribute_literal(&field.attrs, "bias") { + field_values.extend(quote! { + #bias }); + } + + field_values.extend(quote! { , }); + } + + // surround the field values with either () or {} brackets + match fields { + syn::Fields::Named(_) => { + Some(proc_macro2::Group::new( + proc_macro2::Delimiter::Brace, + field_values + )) + }, + syn::Fields::Unnamed(_) => { + Some(proc_macro2::Group::new( + proc_macro2::Delimiter::Parenthesis, + field_values + )) + }, + syn::Fields::Unit => None, + } +} fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream { let name = &ast.ident; + let mut function_body; + match &ast.data { syn::Data::Enum(enumeration) => { let variants = &enumeration.variants; let epsilon: f64 = 1e-9; let one_minus_epsilon = 1.0 - epsilon; - let mut function_body = quote! { + function_body = quote! { let mut variant: f64 = rng.gen_range(0.0..#one_minus_epsilon); }; @@ -30,66 +107,17 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream { // parse enum fields for variant in variants.iter() { - let probability: f64 = { - let attr = variant.attrs.iter().find(|a| { - let path = &a.path; - if let Some(ident) = path.get_ident() { - ident == "prob" - } else { - false - } - }); - let Some(attr) = attr else { - panic!("Variant {} has no probability", variant.ident) - }; - let tokens: TokenStream = attr.tokens.clone().into(); - let tokens: Vec<TokenTree> = tokens.into_iter().collect(); - if tokens.len() != 2 { - panic!("Expected prob = <floating-point number>"); - } - match &tokens[0] { - Punct(equals) if equals.as_char() == '=' => {} - _ => panic!("Expected = after prob attribute"), - }; - - let Literal(literal) = &tokens[1] else { - panic!("Bad number for prob attribute."); - }; - literal.to_string().parse().expect("Bad number for prob attribute") + let probability: Option<f64> = parse_attribute_value(&variant.attrs, "prob"); + let Some(probability) = probability else { + panic!("Variant {} has no probability", variant.ident) }; let name = &variant.ident; - - - let mut variant_arguments = quote! {}; - for field in variant.fields.iter() { - if let Some(name) = &field.ident { - variant_arguments.extend(quote! {#name: }); - } - let ty = &field.ty; - variant_arguments.extend(quote! { <#ty as GenRandom>::gen_random(rng), }); - } - - // surround the arguments with either () or {} brackets - let constructor_group = match variant.fields { - syn::Fields::Named(_) => { - Some(proc_macro2::Group::new( - proc_macro2::Delimiter::Brace, - variant_arguments - )) - }, - syn::Fields::Unnamed(_) => { - Some(proc_macro2::Group::new( - proc_macro2::Delimiter::Parenthesis, - variant_arguments - )) - }, - syn::Fields::Unit => None, - }; + let field_values = generate_fields(&variant.fields); function_body.extend(quote! { variant -= #probability; - if variant <= 0.0 { return Self::#name #constructor_group; } + if variant <= 0.0 { return Self::#name #field_values; } }); test_variant -= probability; @@ -104,17 +132,25 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream { panic!("RNG returned value outside of range.") }); - let gen = quote! { - impl GenRandom for #name { - fn gen_random(rng: &mut impl rand::Rng) -> Self { - #function_body - } - } + }, + syn::Data::Struct(structure) => { + let field_values = generate_fields(&structure.fields); + function_body = quote! { + Self #field_values }; - //println!("{gen}"); - gen.into() }, - _ => panic!("derive(GenRandom) can currently only be applied to enums."), - } + syn::Data::Union(_) => panic!("derive(GenRandom) cannot be applied to unions."), + }; + + let gen = quote! { + impl GenRandom for #name { + fn gen_random(rng: &mut impl rand::Rng) -> Self { + #function_body + } + } + }; + + //println!("{gen}"); + gen.into() } |