summaryrefslogtreecommitdiff
path: root/gen_random_proc_macro/src/lib.rs
blob: e621a76f98c05a0b46324843afebb54fd05cda97 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
extern crate quote;
extern crate proc_macro2;
extern crate syn;

use proc_macro::{TokenStream, TokenTree::{self, Punct, Literal}};
use quote::quote;

/// Entry point for `#[derive(GenRandom)]`.
///
/// Parses the annotated item into a `syn::DeriveInput` and delegates to
/// `impl_gen_random`, which produces the trait implementation. The `prob`
/// helper attribute is registered here so it can be written on the item.
///
/// # Panics
///
/// Panics if `input` cannot be parsed as a derive input.
#[proc_macro_derive(GenRandom, attributes(prob))]
pub fn gen_random_derive(input: TokenStream) -> TokenStream {
	// Parse into a syntax tree first; the impl builder works on that tree.
	impl_gen_random(&syn::parse::<syn::DeriveInput>(input).unwrap())
}


/// Builds the `impl GenRandom for ...` block for the enum described by `ast`.
///
/// Every variant must carry a `#[prob = <number>]` attribute, and the
/// probabilities must add up to 1 (the sum is accepted when it lies within
/// `1e-9` of 1). The generated `gen_random` draws a value from
/// `0.0..(1 - 1e-9)`, subtracts each variant's probability in declaration
/// order, and returns the first variant for which the running value becomes
/// zero or negative. Variant fields are produced by calling
/// `<FieldType as GenRandom>::gen_random(rng)`.
///
/// The enum's generic parameters and `where` clause are forwarded to the
/// generated impl. No extra bounds are added, so a type parameter used as a
/// field type needs its own `GenRandom` bound on the enum.
///
/// # Panics
///
/// Panics (surfacing as a compile error at the derive site) if `ast` is not an
/// enum, if a variant has a missing or malformed `prob` attribute, or if the
/// probabilities do not add up to 1.
fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream {
	let name = &ast.ident;
	let syn::Data::Enum(enumeration) = &ast.data else {
		panic!("derive(GenRandom) can currently only be applied to enums.")
	};

	let epsilon: f64 = 1e-9;
	let one_minus_epsilon = 1.0 - epsilon;
	let mut function_body = quote! {
		let mut variant: f64 = rng.gen_range(0.0..#one_minus_epsilon);
	};

	// Mirror of the generated subtraction chain, starting from the largest
	// value the generated code can sample. It is used below to check the sum
	// of the probabilities while the macro is being expanded.
	let mut test_variant = one_minus_epsilon;

	for variant in enumeration.variants.iter() {
		let probability = variant_probability(variant);
		let variant_name = &variant.ident;
		let constructor_group = variant_constructor_group(&variant.fields);

		function_body.extend(quote! {
			variant -= #probability;
			if variant <= 0.0 { return Self::#variant_name #constructor_group; }
		});
		test_variant -= probability;
	}

	// Accept only sums in (1 - epsilon, 1 + epsilon]: the remainder must have
	// gone negative, but by no more than 2 * epsilon.
	if test_variant >= 0.0 || test_variant < -2.0 * epsilon {
		panic!("Probabilities for enum do not add up to 1 (final test_variant = {test_variant}).");
	}

	// Fallback if none of the generated `return`s fired.
	function_body.extend(quote! {
		panic!("RNG returned value outside of range.")
	});

	// Forward generics so enums with lifetime/const/type parameters get a
	// well-formed impl header; for non-generic enums these expand to nothing.
	let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
	let expanded = quote! {
		impl #impl_generics GenRandom for #name #ty_generics #where_clause {
			fn gen_random(rng: &mut impl rand::Rng) -> Self {
				#function_body
			}
		}
	};
	expanded.into()
}

/// Reads the probability from a variant's `#[prob = <number>]` attribute.
///
/// Only the first `prob` attribute on the variant is considered.
///
/// # Panics
///
/// Panics if the variant has no `prob` attribute, if the attribute is not
/// exactly `= <literal>`, or if the literal does not parse as an `f64`.
fn variant_probability(variant: &syn::Variant) -> f64 {
	let Some(attr) = variant.attrs.iter().find(|a| a.path.is_ident("prob")) else {
		panic!("Variant {} has no probability", variant.ident)
	};

	// Everything after the attribute name, e.g. `= 0.25`.
	let tokens: TokenStream = attr.tokens.clone().into();
	let tokens: Vec<TokenTree> = tokens.into_iter().collect();
	let [equals, value] = tokens.as_slice() else {
		panic!("Expected prob = <floating-point number>");
	};
	if !matches!(equals, Punct(punct) if punct.as_char() == '=') {
		panic!("Expected = after prob attribute");
	}
	let Literal(literal) = value else {
		panic!("Bad number for prob attribute.");
	};
	literal.to_string().parse().expect("Bad number for prob attribute")
}

/// Builds the delimited argument list that follows `Self::Variant` in the
/// generated constructor expression.
///
/// Returns a `{ name: <gen>, ... }` group for named fields, a `(<gen>, ...)`
/// group for tuple fields, and `None` for unit variants. Each `<gen>` is a
/// call to `<FieldType as GenRandom>::gen_random(rng)`.
fn variant_constructor_group(fields: &syn::Fields) -> Option<proc_macro2::Group> {
	let delimiter = match fields {
		syn::Fields::Named(_) => proc_macro2::Delimiter::Brace,
		syn::Fields::Unnamed(_) => proc_macro2::Delimiter::Parenthesis,
		syn::Fields::Unit => return None,
	};

	let mut arguments = quote! {};
	for field in fields.iter() {
		// Named fields need the `name:` prefix; tuple fields are positional.
		if let Some(field_name) = &field.ident {
			arguments.extend(quote! {#field_name: });
		}
		let ty = &field.ty;
		arguments.extend(quote! { <#ty as GenRandom>::gen_random(rng), });
	}
	Some(proc_macro2::Group::new(delimiter, arguments))
}