summaryrefslogtreecommitdiff
path: root/gen_random_proc_macro/src
diff options
context:
space:
mode:
authorpommicket <pommicket@gmail.com>2022-12-21 15:10:22 -0500
committerpommicket <pommicket@gmail.com>2022-12-21 15:10:22 -0500
commit1bc45db77e3d1aaf6c620248b8e598cdc212112f (patch)
treeeb534b13feb2e4e8ca1a6783b22f46f0fc043019 /gen_random_proc_macro/src
parentc81d53fa47863d80436ce808b16c836ea6d3e16c (diff)
params, twisty
Diffstat (limited to 'gen_random_proc_macro/src')
-rw-r--r--gen_random_proc_macro/src/lib.rs93
1 files changed, 31 insertions, 62 deletions
diff --git a/gen_random_proc_macro/src/lib.rs b/gen_random_proc_macro/src/lib.rs
index 9b7a771..961b2a0 100644
--- a/gen_random_proc_macro/src/lib.rs
+++ b/gen_random_proc_macro/src/lib.rs
@@ -10,7 +10,7 @@ use std::any::type_name;
use quote::quote;
/// See `gen_random::GenRandom`.
-#[proc_macro_derive(GenRandom, attributes(prob, scale, bias))]
+#[proc_macro_derive(GenRandom, attributes(prob, scale, bias, params, only_if))]
pub fn gen_random_derive(input: TokenStream) -> TokenStream {
// Construct a representation of Rust code as a syntax tree
// that we can manipulate
@@ -19,7 +19,7 @@ pub fn gen_random_derive(input: TokenStream) -> TokenStream {
impl_gen_random(&ast)
}
-fn get_attribute(attrs: &[syn::Attribute], name: &str) -> Option<proc_macro2::TokenStream> {
+fn get_attribute(attrs: &[syn::Attribute], name: &str) -> Option<TokenStream2> {
let attr = attrs.iter().find(|a| {
let path = &a.path;
if let Some(ident) = path.get_ident() {
@@ -56,14 +56,14 @@ fn parse_attribute_value<T: FromStr>(attrs: &[syn::Attribute], name: &str) -> Op
Some(value)
}
-fn generate_fields(fields: &syn::Fields) -> impl quote::ToTokens {
+fn generate_fields(fields: &syn::Fields, params_type: &TokenStream2) -> impl quote::ToTokens {
let mut field_values = quote! {};
for field in fields.iter() {
if let Some(name) = &field.ident {
field_values.extend(quote! {#name: });
}
let ty = &field.ty;
- field_values.extend(quote! { <#ty as GenRandom>::gen_random_max_depth(rng, _depth - 1) });
+ field_values.extend(quote! { <#ty as GenRandom<#params_type>>::gen_random_params(rng, <#params_type as GenRandomParams>::inc_depth(params)) });
if let Some(scale) = get_attribute(&field.attrs, "scale") {
field_values.extend(quote! { * ( #scale ) });
@@ -93,95 +93,64 @@ fn generate_fields(fields: &syn::Fields) -> impl quote::ToTokens {
}
}
-// very very precise summation algorithm
-// see https://en.wikipedia.org/wiki/Kahan_summation_algorithm
-fn kahan_sum(it: impl IntoIterator<Item = f64>) -> f64 {
- let mut it = it.into_iter();
- let mut sum = 0.0;
- let mut c = 0.0;
- while let Some(x) = it.next() {
- let y = x - c;
- let t = sum + y;
- c = (t - sum) - y;
- sum = t;
- }
- sum
-}
-
fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let mut function_body;
+ let params_type = get_attribute(&ast.attrs, "params").unwrap_or(quote! { () });
match &ast.data {
syn::Data::Enum(enumeration) => {
let variants = &enumeration.variants;
+ function_body = quote! {
+ let mut prob_sum = 0.0;
+ };
- let prob_sum = kahan_sum(variants.iter().map(|variant| {
- match parse_attribute_value(&variant.attrs, "prob") {
+ for variant in variants.iter() {
+ match parse_attribute_value::<f64>(&variant.attrs, "prob") {
Some(prob) => if prob >= 0.0 {
- prob
+ let only_if = get_attribute(&variant.attrs, "only_if")
+ .unwrap_or(quote! { true });
+
+ function_body.extend(quote! {
+ if #only_if { prob_sum += #prob; }
+ });
} else {
panic!("Variant {} has negative probability", variant.ident)
},
None => panic!("Variant {} has no probability", variant.ident)
}
- }));
-
- if prob_sum <= f64::EPSILON {
- panic!("Sum of probabilties is (basically) zero.");
}
- // ideally we would just do
- // let mut variant: f64 = rng.gen_range(0.0..prob_sum);
- // variant -= variant1_probability;
- // if variant < 0.0 { bla bla bla }
- // variant -= variant2_probability;
- // if variant < 0.0 { bla bla bla }
- // etc.
- // but because of floating point imprecision, it's possible
- // that all if conditions are false.
- // however we know that for each subtraction at most one ULP is lost.
- // so we'll be fine as long as we put the end of the range at
- // prob_sum * (1.0 - (variant_count + 2) * ULP)
- // the + 2 is for the imprecision lost in kahan_sum and one more just to be sure.
-
- let variant_max = prob_sum * (1.0 - f64::EPSILON * (variants.len() + 2) as f64);
- function_body = quote! {
- let mut variant: f64 = rng.gen_range(0.0..=#variant_max);
- };
-
- // this test value ensures that the gen_random function never panicks.
- let mut test_variant = variant_max;
+ let compensation = (variants.len() + 1) as f64 * f64::EPSILON;
+ function_body.extend(quote! {
+ let mut variant: f64 = rng.gen_range(0.0..prob_sum - #compensation);
+ });
// parse enum fields
for variant in variants.iter() {
// Note: None case was checked above when computing prob_sum
let probability: f64 = parse_attribute_value(&variant.attrs, "prob").unwrap();
+ let only_if = get_attribute(&variant.attrs, "only_if")
+ .unwrap_or(quote! { true });
let name = &variant.ident;
- let field_values = generate_fields(&variant.fields);
+ let field_values = generate_fields(&variant.fields, &params_type);
function_body.extend(quote! {
- variant -= #probability;
- // note: if _depth <= 0, we will always return the first variant.
- if _depth <= 0 || variant < 0.0 { return Self::#name #field_values; }
+ if #only_if {
+ variant -= #probability;
+ if variant < 0.0 { return Self::#name #field_values; }
+ }
});
- test_variant -= probability;
-
-
- }
-
- if test_variant >= 0.0 {
- panic!("i did floating-point math wrong. this should never happen. (test_variant = {test_variant})");
}
function_body.extend(quote! {
- panic!("RNG returned value outside of range.")
+ panic!("RNG returned value outside of range (this should never happen).")
});
},
syn::Data::Struct(structure) => {
- let field_values = generate_fields(&structure.fields);
+ let field_values = generate_fields(&structure.fields, &params_type);
function_body = quote! {
Self #field_values
};
@@ -190,8 +159,8 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream {
};
let gen = quote! {
- impl GenRandom for #name {
- fn gen_random_max_depth(rng: &mut impl rand::Rng, _depth: isize) -> Self {
+ impl GenRandom<#params_type> for #name {
+ fn gen_random_params(rng: &mut impl rand::Rng, params: #params_type) -> Self {
#function_body
}
}