summaryrefslogtreecommitdiff
path: root/gen_random_proc_macro
diff options
context:
space:
mode:
authorpommicket <pommicket@gmail.com>2022-12-14 16:16:57 -0500
committerpommicket <pommicket@gmail.com>2022-12-14 16:16:57 -0500
commiteff66f8056b01a732df9523cb3a3d06b2d69c750 (patch)
treebe4f2bb7ea2dc765c0c7ded04be92993cab1ba4c /gen_random_proc_macro
parent3f77c9f224c935aa56793bd548944e991cd4b0cd (diff)
nicer GenRandom
Diffstat (limited to 'gen_random_proc_macro')
-rw-r--r--gen_random_proc_macro/src/lib.rs10
1 files changed, 6 insertions, 4 deletions
diff --git a/gen_random_proc_macro/src/lib.rs b/gen_random_proc_macro/src/lib.rs
index fa2d5dd..5e28e16 100644
--- a/gen_random_proc_macro/src/lib.rs
+++ b/gen_random_proc_macro/src/lib.rs
@@ -9,6 +9,7 @@ use std::str::FromStr;
use std::any::type_name;
use quote::quote;
+/// See `gen_random::GenRandom`.
#[proc_macro_derive(GenRandom, attributes(prob, scale, bias))]
pub fn gen_random_derive(input: TokenStream) -> TokenStream {
// Construct a representation of Rust code as a syntax tree
@@ -60,7 +61,7 @@ fn generate_fields(fields: &syn::Fields) -> impl quote::ToTokens {
field_values.extend(quote! {#name: });
}
let ty = &field.ty;
- field_values.extend(quote! { <#ty as GenRandom>::gen_random(rng) });
+ field_values.extend(quote! { <#ty as GenRandom>::gen_random_max_depth(rng, _depth - 1) });
if let Some(scale) = get_attribute_literal(&field.attrs, "scale") {
field_values.extend(quote! { * #scale });
@@ -129,7 +130,7 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream {
}
// ideally we would just do
- // let mut variant: f64 = rng.gen_range(0.0..1.0);
+ // let mut variant: f64 = rng.gen_range(0.0..prob_sum);
// variant -= variant1_probability;
// if variant < 0.0 { bla bla bla }
// variant -= variant2_probability;
@@ -160,7 +161,8 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream {
function_body.extend(quote! {
variant -= #probability;
- if variant < 0.0 { return Self::#name #field_values; }
+ // note: if _depth <= 0, we will always return the first variant.
+ if _depth <= 0 || variant < 0.0 { return Self::#name #field_values; }
});
test_variant -= probability;
@@ -187,7 +189,7 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream {
let gen = quote! {
impl GenRandom for #name {
- fn gen_random(rng: &mut impl rand::Rng) -> Self {
+ fn gen_random_max_depth(rng: &mut impl rand::Rng, _depth: isize) -> Self {
#function_body
}
}