diff options
author | pommicket <pommicket@gmail.com> | 2022-12-21 15:10:22 -0500 |
---|---|---|
committer | pommicket <pommicket@gmail.com> | 2022-12-21 15:10:22 -0500 |
commit | 1bc45db77e3d1aaf6c620248b8e598cdc212112f (patch) | |
tree | eb534b13feb2e4e8ca1a6783b22f46f0fc043019 /gen_random_proc_macro/src | |
parent | c81d53fa47863d80436ce808b16c836ea6d3e16c (diff) |
params, twisty
Diffstat (limited to 'gen_random_proc_macro/src')
-rw-r--r-- | gen_random_proc_macro/src/lib.rs | 93 |
1 files changed, 31 insertions, 62 deletions
diff --git a/gen_random_proc_macro/src/lib.rs b/gen_random_proc_macro/src/lib.rs index 9b7a771..961b2a0 100644 --- a/gen_random_proc_macro/src/lib.rs +++ b/gen_random_proc_macro/src/lib.rs @@ -10,7 +10,7 @@ use std::any::type_name; use quote::quote; /// See `gen_random::GenRandom`. -#[proc_macro_derive(GenRandom, attributes(prob, scale, bias))] +#[proc_macro_derive(GenRandom, attributes(prob, scale, bias, params, only_if))] pub fn gen_random_derive(input: TokenStream) -> TokenStream { // Construct a representation of Rust code as a syntax tree // that we can manipulate @@ -19,7 +19,7 @@ pub fn gen_random_derive(input: TokenStream) -> TokenStream { impl_gen_random(&ast) } -fn get_attribute(attrs: &[syn::Attribute], name: &str) -> Option<proc_macro2::TokenStream> { +fn get_attribute(attrs: &[syn::Attribute], name: &str) -> Option<TokenStream2> { let attr = attrs.iter().find(|a| { let path = &a.path; if let Some(ident) = path.get_ident() { @@ -56,14 +56,14 @@ fn parse_attribute_value<T: FromStr>(attrs: &[syn::Attribute], name: &str) -> Op Some(value) } -fn generate_fields(fields: &syn::Fields) -> impl quote::ToTokens { +fn generate_fields(fields: &syn::Fields, params_type: &TokenStream2) -> impl quote::ToTokens { let mut field_values = quote! {}; for field in fields.iter() { if let Some(name) = &field.ident { field_values.extend(quote! {#name: }); } let ty = &field.ty; - field_values.extend(quote! { <#ty as GenRandom>::gen_random_max_depth(rng, _depth - 1) }); + field_values.extend(quote! { <#ty as GenRandom<#params_type>>::gen_random_params(rng, <#params_type as GenRandomParams>::inc_depth(params)) }); if let Some(scale) = get_attribute(&field.attrs, "scale") { field_values.extend(quote! { * ( #scale ) }); @@ -93,95 +93,64 @@ fn generate_fields(fields: &syn::Fields) -> impl quote::ToTokens { } } -// very very precise summation algorithm -// see https://en.wikipedia.org/wiki/Kahan_summation_algorithm -fn kahan_sum(it: impl IntoIterator<Item = f64>) -> f64 { - let mut it = it.into_iter(); - let mut sum = 0.0; - let mut c = 0.0; - while let Some(x) = it.next() { - let y = x - c; - let t = sum + y; - c = (t - sum) - y; - sum = t; - } - sum -} - fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream { let name = &ast.ident; let mut function_body; + let params_type = get_attribute(&ast.attrs, "params").unwrap_or(quote! { () }); match &ast.data { syn::Data::Enum(enumeration) => { let variants = &enumeration.variants; + function_body = quote! { + let mut prob_sum = 0.0; + }; - let prob_sum = kahan_sum(variants.iter().map(|variant| { - match parse_attribute_value(&variant.attrs, "prob") { + for variant in variants.iter() { + match parse_attribute_value::<f64>(&variant.attrs, "prob") { Some(prob) => if prob >= 0.0 { - prob + let only_if = get_attribute(&variant.attrs, "only_if") + .unwrap_or(quote! { true }); + + function_body.extend(quote! { + if #only_if { prob_sum += #prob; } + }); } else { panic!("Variant {} has negative probability", variant.ident) }, None => panic!("Variant {} has no probability", variant.ident) } - })); - - if prob_sum <= f64::EPSILON { - panic!("Sum of probabilties is (basically) zero."); } - // ideally we would just do - // let mut variant: f64 = rng.gen_range(0.0..prob_sum); - // variant -= variant1_probability; - // if variant < 0.0 { bla bla bla } - // variant -= variant2_probability; - // if variant < 0.0 { bla bla bla } - // etc. - // but because of floating point imprecision, it's possible - // that all if conditions are false. - // however we know that for each subtraction at most one ULP is lost. - // so we'll be fine as long as we put the end of the range at - // prob_sum * (1.0 - (variant_count + 2) * ULP) - // the + 2 is for the imprecision lost in kahan_sum and one more just to be sure. - - let variant_max = prob_sum * (1.0 - f64::EPSILON * (variants.len() + 2) as f64); - function_body = quote! { - let mut variant: f64 = rng.gen_range(0.0..=#variant_max); - }; - - // this test value ensures that the gen_random function never panicks. - let mut test_variant = variant_max; + let compensation = (variants.len() + 1) as f64 * f64::EPSILON; + function_body.extend(quote! { + let mut variant: f64 = rng.gen_range(0.0..prob_sum - #compensation); + }); // parse enum fields for variant in variants.iter() { // Note: None case was checked above when computing prob_sum let probability: f64 = parse_attribute_value(&variant.attrs, "prob").unwrap(); + let only_if = get_attribute(&variant.attrs, "only_if") + .unwrap_or(quote! { true }); let name = &variant.ident; - let field_values = generate_fields(&variant.fields); + let field_values = generate_fields(&variant.fields, ¶ms_type); function_body.extend(quote! { - variant -= #probability; - // note: if _depth <= 0, we will always return the first variant. - if _depth <= 0 || variant < 0.0 { return Self::#name #field_values; } + if #only_if { + variant -= #probability; + if variant < 0.0 { return Self::#name #field_values; } + } }); - test_variant -= probability; - - - } - - if test_variant >= 0.0 { - panic!("i did floating-point math wrong. this should never happen. (test_variant = {test_variant})"); } function_body.extend(quote! { - panic!("RNG returned value outside of range.") + panic!("RNG returned value outside of range (this should never happen).") }); }, syn::Data::Struct(structure) => { - let field_values = generate_fields(&structure.fields); + let field_values = generate_fields(&structure.fields, ¶ms_type); function_body = quote! { Self #field_values }; @@ -190,8 +159,8 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream { }; let gen = quote! { - impl GenRandom for #name { - fn gen_random_max_depth(rng: &mut impl rand::Rng, _depth: isize) -> Self { + impl GenRandom<#params_type> for #name { + fn gen_random_params(rng: &mut impl rand::Rng, params: #params_type) -> Self { #function_body } } |