diff options
author | pommicket <pommicket@gmail.com> | 2022-12-14 15:24:18 -0500 |
---|---|---|
committer | pommicket <pommicket@gmail.com> | 2022-12-14 15:24:18 -0500 |
commit | 3f77c9f224c935aa56793bd548944e991cd4b0cd (patch) | |
tree | 0b28a7fa316a23f9e8254e74be96629e2e0be916 /gen_random_proc_macro/src | |
parent | e8bc993ed558e1b25a31e6a6fabac7853e1b1035 (diff) |
probabilities which don't add up to 1
Diffstat (limited to 'gen_random_proc_macro/src')
-rw-r--r-- | gen_random_proc_macro/src/lib.rs | 65 |
1 file changed, 54 insertions, 11 deletions
diff --git a/gen_random_proc_macro/src/lib.rs b/gen_random_proc_macro/src/lib.rs index 1a620c3..fa2d5dd 100644 --- a/gen_random_proc_macro/src/lib.rs +++ b/gen_random_proc_macro/src/lib.rs @@ -90,6 +90,21 @@ fn generate_fields(fields: &syn::Fields) -> impl quote::ToTokens { } } +// very very precise summation algorithm +// see https://en.wikipedia.org/wiki/Kahan_summation_algorithm +fn kahan_sum(it: impl IntoIterator<Item = f64>) -> f64 { + let mut it = it.into_iter(); + let mut sum = 0.0; + let mut c = 0.0; + while let Some(x) = it.next() { + let y = x - c; + let t = sum + y; + c = (t - sum) - y; + sum = t; + } + sum +} + fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream { let name = &ast.ident; let mut function_body; @@ -97,35 +112,63 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream { match &ast.data { syn::Data::Enum(enumeration) => { let variants = &enumeration.variants; - let epsilon: f64 = 1e-9; - let one_minus_epsilon = 1.0 - epsilon; + + let prob_sum = kahan_sum(variants.iter().map(|variant| { + match parse_attribute_value(&variant.attrs, "prob") { + Some(prob) => if prob >= 0.0 { + prob + } else { + panic!("Variant {} has negative probability", variant.ident) + }, + None => panic!("Variant {} has no probability", variant.ident) + } + })); + + if prob_sum <= f64::EPSILON { + panic!("Sum of probabilties is (basically) zero."); + } + + // ideally we would just do + // let mut variant: f64 = rng.gen_range(0.0..1.0); + // variant -= variant1_probability; + // if variant < 0.0 { bla bla bla } + // variant -= variant2_probability; + // if variant < 0.0 { bla bla bla } + // etc. + // but because of floating point imprecision, it's possible + // that all if conditions are false. + // however we know that for each subtraction at most one ULP is lost. + // so we'll be fine as long as we put the end of the range at + // prob_sum * (1.0 - (variant_count + 2) * ULP) + // the + 2 is for the imprecision lost in kahan_sum and one more just to be sure. 
+ + let variant_max = prob_sum * (1.0 - f64::EPSILON * (variants.len() + 2) as f64); function_body = quote! { - let mut variant: f64 = rng.gen_range(0.0..#one_minus_epsilon); + let mut variant: f64 = rng.gen_range(0.0..=#variant_max); }; - let mut test_variant = one_minus_epsilon; + // this test value ensures that the gen_random function never panicks. + let mut test_variant = variant_max; // parse enum fields for variant in variants.iter() { - let probability: Option<f64> = parse_attribute_value(&variant.attrs, "prob"); - let Some(probability) = probability else { - panic!("Variant {} has no probability", variant.ident) - }; + // Note: None case was checked above when computing prob_sum + let probability: f64 = parse_attribute_value(&variant.attrs, "prob").unwrap(); let name = &variant.ident; let field_values = generate_fields(&variant.fields); function_body.extend(quote! { variant -= #probability; - if variant <= 0.0 { return Self::#name #field_values; } + if variant < 0.0 { return Self::#name #field_values; } }); test_variant -= probability; } - if test_variant >= 0.0 || test_variant < -2.0 * epsilon { - panic!("Probabilities for enum do not add up to 1 (final test_variant = {test_variant})."); + if test_variant >= 0.0 { + panic!("i did floating-point math wrong. this should never happen. (test_variant = {test_variant})"); } function_body.extend(quote! { |