summaryrefslogtreecommitdiff
path: root/gen_random_proc_macro
diff options
context:
space:
mode:
Diffstat (limited to 'gen_random_proc_macro')
-rw-r--r--gen_random_proc_macro/src/lib.rs65
1 files changed, 54 insertions, 11 deletions
diff --git a/gen_random_proc_macro/src/lib.rs b/gen_random_proc_macro/src/lib.rs
index 1a620c3..fa2d5dd 100644
--- a/gen_random_proc_macro/src/lib.rs
+++ b/gen_random_proc_macro/src/lib.rs
@@ -90,6 +90,21 @@ fn generate_fields(fields: &syn::Fields) -> impl quote::ToTokens {
}
}
+// very very precise summation algorithm
+// see https://en.wikipedia.org/wiki/Kahan_summation_algorithm
+fn kahan_sum(it: impl IntoIterator<Item = f64>) -> f64 {
+ let mut it = it.into_iter();
+ let mut sum = 0.0;
+ let mut c = 0.0;
+ while let Some(x) = it.next() {
+ let y = x - c;
+ let t = sum + y;
+ c = (t - sum) - y;
+ sum = t;
+ }
+ sum
+}
+
fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let mut function_body;
@@ -97,35 +112,63 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream {
match &ast.data {
syn::Data::Enum(enumeration) => {
let variants = &enumeration.variants;
- let epsilon: f64 = 1e-9;
- let one_minus_epsilon = 1.0 - epsilon;
+
+ let prob_sum = kahan_sum(variants.iter().map(|variant| {
+ match parse_attribute_value(&variant.attrs, "prob") {
+ Some(prob) => if prob >= 0.0 {
+ prob
+ } else {
+ panic!("Variant {} has negative probability", variant.ident)
+ },
+ None => panic!("Variant {} has no probability", variant.ident)
+ }
+ }));
+
+ if prob_sum <= f64::EPSILON {
+ panic!("Sum of probabilties is (basically) zero.");
+ }
+
+ // ideally we would just do
+ // let mut variant: f64 = rng.gen_range(0.0..1.0);
+ // variant -= variant1_probability;
+ // if variant < 0.0 { bla bla bla }
+ // variant -= variant2_probability;
+ // if variant < 0.0 { bla bla bla }
+ // etc.
+ // but because of floating point imprecision, it's possible
+ // that all if conditions are false.
+ // however we know that for each subtraction at most one ULP is lost.
+ // so we'll be fine as long as we put the end of the range at
+ // prob_sum * (1.0 - (variant_count + 2) * ULP)
+ // the + 2 is for the imprecision lost in kahan_sum and one more just to be sure.
+
+ let variant_max = prob_sum * (1.0 - f64::EPSILON * (variants.len() + 2) as f64);
function_body = quote! {
- let mut variant: f64 = rng.gen_range(0.0..#one_minus_epsilon);
+ let mut variant: f64 = rng.gen_range(0.0..=#variant_max);
};
- let mut test_variant = one_minus_epsilon;
+ // this test value ensures that the gen_random function never panicks.
+ let mut test_variant = variant_max;
// parse enum fields
for variant in variants.iter() {
- let probability: Option<f64> = parse_attribute_value(&variant.attrs, "prob");
- let Some(probability) = probability else {
- panic!("Variant {} has no probability", variant.ident)
- };
+ // Note: None case was checked above when computing prob_sum
+ let probability: f64 = parse_attribute_value(&variant.attrs, "prob").unwrap();
let name = &variant.ident;
let field_values = generate_fields(&variant.fields);
function_body.extend(quote! {
variant -= #probability;
- if variant <= 0.0 { return Self::#name #field_values; }
+ if variant < 0.0 { return Self::#name #field_values; }
});
test_variant -= probability;
}
- if test_variant >= 0.0 || test_variant < -2.0 * epsilon {
- panic!("Probabilities for enum do not add up to 1 (final test_variant = {test_variant}).");
+ if test_variant >= 0.0 {
+ panic!("i did floating-point math wrong. this should never happen. (test_variant = {test_variant})");
}
function_body.extend(quote! {