 gen_random_proc_macro/src/lib.rs | 65
 gen_random_test/src/lib.rs       | 33
 2 files changed, 70 insertions(+), 28 deletions(-)
diff --git a/gen_random_proc_macro/src/lib.rs b/gen_random_proc_macro/src/lib.rs
index 1a620c3..fa2d5dd 100644
--- a/gen_random_proc_macro/src/lib.rs
+++ b/gen_random_proc_macro/src/lib.rs
@@ -90,6 +90,21 @@ fn generate_fields(fields: &syn::Fields) -> impl quote::ToTokens {
}
}
+// very very precise summation algorithm
+// see https://en.wikipedia.org/wiki/Kahan_summation_algorithm
+fn kahan_sum(it: impl IntoIterator<Item = f64>) -> f64 {
+ let mut it = it.into_iter();
+ let mut sum = 0.0;
+ let mut c = 0.0;
+ while let Some(x) = it.next() {
+ let y = x - c;
+ let t = sum + y;
+ c = (t - sum) - y;
+ sum = t;
+ }
+ sum
+}
+
fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let mut function_body;
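
For reference, a standalone sketch comparing the compensated summation above with a naive sum; kahan_sum is restated here (slightly simplified) so the sketch compiles on its own, and the term count and value are only illustrative:

    fn kahan_sum(it: impl IntoIterator<Item = f64>) -> f64 {
        let mut sum = 0.0;
        let mut c = 0.0; // running compensation for lost low-order bits
        for x in it {
            let y = x - c;      // apply the correction from the previous step
            let t = sum + y;    // low-order bits of y can be lost here...
            c = (t - sum) - y;  // ...but this recovers them (negated) for next time
            sum = t;
        }
        sum
    }

    fn main() {
        let terms = || (0..10_000_000).map(|_| 0.1_f64);
        let naive: f64 = terms().sum();
        let compensated = kahan_sum(terms());
        // The naive sum drifts visibly from the true value, while the
        // compensated sum stays within a few ULPs of it.
        println!("naive:       {naive:.10}");
        println!("compensated: {compensated:.10}");
    }
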
@@ -97,35 +112,63 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream {
match &ast.data {
syn::Data::Enum(enumeration) => {
let variants = &enumeration.variants;
- let epsilon: f64 = 1e-9;
- let one_minus_epsilon = 1.0 - epsilon;
+
+ let prob_sum = kahan_sum(variants.iter().map(|variant| {
+ match parse_attribute_value(&variant.attrs, "prob") {
+ Some(prob) => if prob >= 0.0 {
+ prob
+ } else {
+ panic!("Variant {} has negative probability", variant.ident)
+ },
+ None => panic!("Variant {} has no probability", variant.ident)
+ }
+ }));
+
+ if prob_sum <= f64::EPSILON {
+ panic!("Sum of probabilties is (basically) zero.");
+ }
+
+ // ideally we would just do
+ // let mut variant: f64 = rng.gen_range(0.0..1.0);
+ // variant -= variant1_probability;
+ // if variant < 0.0 { bla bla bla }
+ // variant -= variant2_probability;
+ // if variant < 0.0 { bla bla bla }
+ // etc.
+ // but because of floating point imprecision, it's possible
+ // that all if conditions are false.
+ // however we know that for each subtraction at most one ULP is lost.
+ // so we'll be fine as long as we put the end of the range at
+ // prob_sum * (1.0 - (variant_count + 2) * ULP)
+ // the + 2 is for the imprecision lost in kahan_sum and one more just to be sure.
+
+ let variant_max = prob_sum * (1.0 - f64::EPSILON * (variants.len() + 2) as f64);
function_body = quote! {
- let mut variant: f64 = rng.gen_range(0.0..#one_minus_epsilon);
+ let mut variant: f64 = rng.gen_range(0.0..=#variant_max);
};
- let mut test_variant = one_minus_epsilon;
+ // this test value ensures that the gen_random function never panics.
+ let mut test_variant = variant_max;
// parse enum fields
for variant in variants.iter() {
- let probability: Option<f64> = parse_attribute_value(&variant.attrs, "prob");
- let Some(probability) = probability else {
- panic!("Variant {} has no probability", variant.ident)
- };
+ // Note: None case was checked above when computing prob_sum
+ let probability: f64 = parse_attribute_value(&variant.attrs, "prob").unwrap();
let name = &variant.ident;
let field_values = generate_fields(&variant.fields);
function_body.extend(quote! {
variant -= #probability;
- if variant <= 0.0 { return Self::#name #field_values; }
+ if variant < 0.0 { return Self::#name #field_values; }
});
test_variant -= probability;
}
- if test_variant >= 0.0 || test_variant < -2.0 * epsilon {
- panic!("Probabilities for enum do not add up to 1 (final test_variant = {test_variant}).");
+ if test_variant >= 0.0 {
+ panic!("i did floating-point math wrong. this should never happen. (test_variant = {test_variant})");
}
function_body.extend(quote! {
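
Expanded by hand for a hypothetical two-variant enum, the selection logic this generates amounts to the following (the names and weights are made up, and this is only an approximation of the real macro output):

    use rand::Rng;

    enum Coin {
        Heads, // relative weight 0.2
        Tails, // relative weight 0.8
    }

    fn gen_coin(rng: &mut impl Rng) -> Coin {
        let weights = [0.2_f64, 0.8_f64];
        let prob_sum: f64 = weights.iter().sum(); // the macro uses kahan_sum here
        // Pull the upper bound in slightly so that, even if every subtraction
        // below loses a ULP, one of the `variant < 0.0` branches must be taken.
        let variant_max = prob_sum * (1.0 - f64::EPSILON * (weights.len() + 2) as f64);
        let mut variant: f64 = rng.gen_range(0.0..=variant_max);
        variant -= weights[0];
        if variant < 0.0 { return Coin::Heads; }
        variant -= weights[1];
        if variant < 0.0 { return Coin::Tails; }
        unreachable!("variant_max is strictly below the sum of the weights");
    }
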
diff --git a/gen_random_test/src/lib.rs b/gen_random_test/src/lib.rs
index bea6190..0d2b1cc 100644
--- a/gen_random_test/src/lib.rs
+++ b/gen_random_test/src/lib.rs
@@ -1,39 +1,38 @@
-
#[cfg(test)]
mod tests {
- extern crate rand;
- extern crate gen_random_proc_macro;
extern crate gen_random;
- use gen_random::{GenRandom, gen_thread_random_vec};
+ extern crate gen_random_proc_macro;
+ extern crate rand;
+ use gen_random::{gen_thread_random_vec, GenRandom};
use gen_random_proc_macro::GenRandom;
-
+
#[derive(GenRandom, Debug)]
enum Test1 {
#[prob = 0.2]
A(f32),
#[prob = 0.8]
- B(Option<f32>)
+ B(Option<f32>),
}
-
+
#[derive(GenRandom, Debug)]
#[allow(dead_code)]
enum Test2 {
#[prob = 0.1]
Variant1,
#[prob = 0.7]
- Variant2 { x : f32, y: f64, z: Test1 },
+ Variant2 { x: f32, y: f64, z: Test1 },
#[prob = 0.2]
- Variant3(f32, Box<Test2>)
+ Variant3(f32, Box<Test2>),
}
-
+
#[derive(GenRandom, Debug)]
enum LinkedList {
- #[prob = 0.1]
+ #[prob = 10]
Empty,
- #[prob = 0.9]
- Cons(f32, Box<LinkedList>)
+ #[prob = 90]
+ Cons(f32, Box<LinkedList>),
}
-
+
#[derive(GenRandom, Debug)]
struct ScaleBias {
#[bias = 1.0]
@@ -49,19 +48,19 @@ mod tests {
let tests1: Vec<Test1> = gen_thread_random_vec(10);
println!("{tests1:?}");
}
-
+
#[test]
fn many_types_of_variants() {
let tests2: Vec<Test2> = gen_thread_random_vec(10);
println!("{tests2:?}");
}
-
+
#[test]
fn linked_list() {
let ll = LinkedList::gen_thread_random();
println!("{ll:?}");
}
-
+
#[test]
fn scale_bias() {
let sb: Vec<ScaleBias> = gen_thread_random_vec(10);