summaryrefslogtreecommitdiff
path: root/gen_random_proc_macro/src
diff options
context:
space:
mode:
authorpommicket <pommicket@gmail.com>2022-12-14 12:45:48 -0500
committerpommicket <pommicket@gmail.com>2022-12-14 12:45:48 -0500
commit3265cb676c5c87fd624f59aaf3ca89d947df8cda (patch)
tree0bc94cbdddb9cecfc6205f834428e6efafcdb69a /gen_random_proc_macro/src
parent3f839f2ec08ffbb75bee391202e7e0d7432d97e0 (diff)
GenRandom
Diffstat (limited to 'gen_random_proc_macro/src')
-rw-r--r--gen_random_proc_macro/src/lib.rs120
1 files changed, 120 insertions, 0 deletions
diff --git a/gen_random_proc_macro/src/lib.rs b/gen_random_proc_macro/src/lib.rs
new file mode 100644
index 0000000..e621a76
--- /dev/null
+++ b/gen_random_proc_macro/src/lib.rs
@@ -0,0 +1,120 @@
+extern crate quote;
+extern crate proc_macro2;
+extern crate syn;
+
+use proc_macro::{TokenStream, TokenTree::{self, Punct, Literal}};
+use quote::quote;
+
+#[proc_macro_derive(GenRandom, attributes(prob))]
+pub fn gen_random_derive(input: TokenStream) -> TokenStream {
+ // Construct a representation of Rust code as a syntax tree
+ // that we can manipulate
+ let ast = syn::parse(input).unwrap();
+ // Build the trait implementation
+ impl_gen_random(&ast)
+}
+
+
+fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream {
+ let name = &ast.ident;
+ match &ast.data {
+ syn::Data::Enum(enumeration) => {
+ let variants = &enumeration.variants;
+ let epsilon: f64 = 1e-9;
+ let one_minus_epsilon = 1.0 - epsilon;
+ let mut function_body = quote! {
+ let mut variant: f64 = rng.gen_range(0.0..#one_minus_epsilon);
+ };
+
+ let mut test_variant = one_minus_epsilon;
+
+ // parse enum fields
+ for variant in variants.iter() {
+ let probability: f64 = {
+ let attr = variant.attrs.iter().find(|a| {
+ let path = &a.path;
+ if let Some(ident) = path.get_ident() {
+ ident == "prob"
+ } else {
+ false
+ }
+ });
+ let Some(attr) = attr else {
+ panic!("Variant {} has no probability", variant.ident)
+ };
+ let tokens: TokenStream = attr.tokens.clone().into();
+ let tokens: Vec<TokenTree> = tokens.into_iter().collect();
+ if tokens.len() != 2 {
+ panic!("Expected prob = <floating-point number>");
+ }
+ match &tokens[0] {
+ Punct(equals) if equals.as_char() == '=' => {}
+ _ => panic!("Expected = after prob attribute"),
+ };
+
+ let Literal(literal) = &tokens[1] else {
+ panic!("Bad number for prob attribute.");
+ };
+ literal.to_string().parse().expect("Bad number for prob attribute")
+ };
+
+ let name = &variant.ident;
+
+
+ let mut variant_arguments = quote! {};
+ for field in variant.fields.iter() {
+ if let Some(name) = &field.ident {
+ variant_arguments.extend(quote! {#name: });
+ }
+ let ty = &field.ty;
+ variant_arguments.extend(quote! { <#ty as GenRandom>::gen_random(rng), });
+ }
+
+ // surround the arguments with either () or {} brackets
+ let constructor_group = match variant.fields {
+ syn::Fields::Named(_) => {
+ Some(proc_macro2::Group::new(
+ proc_macro2::Delimiter::Brace,
+ variant_arguments
+ ))
+ },
+ syn::Fields::Unnamed(_) => {
+ Some(proc_macro2::Group::new(
+ proc_macro2::Delimiter::Parenthesis,
+ variant_arguments
+ ))
+ },
+ syn::Fields::Unit => None,
+ };
+
+ function_body.extend(quote! {
+ variant -= #probability;
+ if variant <= 0.0 { return Self::#name #constructor_group; }
+ });
+ test_variant -= probability;
+
+
+ }
+
+ if test_variant >= 0.0 || test_variant < -2.0 * epsilon {
+ panic!("Probabilities for enum do not add up to 1 (final test_variant = {test_variant}).");
+ }
+
+ function_body.extend(quote! {
+ panic!("RNG returned value outside of range.")
+ });
+
+ let gen = quote! {
+ impl GenRandom for #name {
+ fn gen_random(rng: &mut impl rand::Rng) -> Self {
+ #function_body
+ }
+ }
+ };
+ //println!("{gen}");
+ gen.into()
+ },
+ _ => panic!("derive(GenRandom) can currently only be applied to enums."),
+ }
+}
+