diff options
author | pommicket <pommicket@gmail.com> | 2022-12-14 14:58:57 -0500 |
---|---|---|
committer | pommicket <pommicket@gmail.com> | 2022-12-14 14:58:57 -0500 |
commit | e8bc993ed558e1b25a31e6a6fabac7853e1b1035 (patch) | |
tree | aaec434df223e58e84c38799a6870e0902b17c6a | |
parent | 5a20cffba66caa71b495736f75031f69d09ba40b (diff) |
GenRandom structs, scale, bias
-rw-r--r-- | gen_random_proc_macro/src/lib.rs | 170 | ||||
-rw-r--r-- | gen_random_test/src/lib.rs | 24 | ||||
-rw-r--r-- | src/main.rs | 9 | ||||
-rw-r--r-- | src/sdf.rs | 6 |
4 files changed, 134 insertions, 75 deletions
diff --git a/gen_random_proc_macro/src/lib.rs b/gen_random_proc_macro/src/lib.rs index e621a76..1a620c3 100644 --- a/gen_random_proc_macro/src/lib.rs +++ b/gen_random_proc_macro/src/lib.rs @@ -2,10 +2,14 @@ extern crate quote; extern crate proc_macro2; extern crate syn; -use proc_macro::{TokenStream, TokenTree::{self, Punct, Literal}}; +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::TokenTree as TokenTree2; +use std::str::FromStr; +use std::any::type_name; use quote::quote; -#[proc_macro_derive(GenRandom, attributes(prob))] +#[proc_macro_derive(GenRandom, attributes(prob, scale, bias))] pub fn gen_random_derive(input: TokenStream) -> TokenStream { // Construct a representation of Rust code as a syntax tree // that we can manipulate @@ -14,15 +18,88 @@ pub fn gen_random_derive(input: TokenStream) -> TokenStream { impl_gen_random(&ast) } +fn get_attribute_literal(attrs: &[syn::Attribute], name: &str) -> Option<proc_macro2::Literal> { + let attr = attrs.iter().find(|a| { + let path = &a.path; + if let Some(ident) = path.get_ident() { + ident == name + } else { + false + } + })?; + + let tokens: TokenStream2 = attr.tokens.clone().into(); + let mut tokens: Vec<TokenTree2> = tokens.into_iter().collect(); + if tokens.len() != 2 { + panic!("Expected {name} = <value>"); + } + use TokenTree2::{Punct, Literal}; + match &tokens[0] { + Punct(equals) if equals.as_char() == '=' => {} + _ => panic!("Expected = after {name} attribute"), + }; + + let Literal(literal) = tokens.remove(1) else { + panic!("Bad value for {name} attribute."); + }; + Some(literal) +} + +fn parse_attribute_value<T: FromStr>(attrs: &[syn::Attribute], name: &str) -> Option<T> { + let literal = get_attribute_literal(attrs, name)?; + let Ok(value) = literal.to_string().parse() else { + panic!("Bad {} for {name} attribute", type_name::<T>()) + }; + Some(value) +} + +fn generate_fields(fields: &syn::Fields) -> impl quote::ToTokens { + let mut field_values = quote! {}; + for field in fields.iter() { + if let Some(name) = &field.ident { + field_values.extend(quote! {#name: }); + } + let ty = &field.ty; + field_values.extend(quote! { <#ty as GenRandom>::gen_random(rng) }); + + if let Some(scale) = get_attribute_literal(&field.attrs, "scale") { + field_values.extend(quote! { * #scale }); + } + if let Some(bias) = get_attribute_literal(&field.attrs, "bias") { + field_values.extend(quote! { + #bias }); + } + + field_values.extend(quote! { , }); + } + + // surround the field values with either () or {} brackets + match fields { + syn::Fields::Named(_) => { + Some(proc_macro2::Group::new( + proc_macro2::Delimiter::Brace, + field_values + )) + }, + syn::Fields::Unnamed(_) => { + Some(proc_macro2::Group::new( + proc_macro2::Delimiter::Parenthesis, + field_values + )) + }, + syn::Fields::Unit => None, + } +} fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream { let name = &ast.ident; + let mut function_body; + match &ast.data { syn::Data::Enum(enumeration) => { let variants = &enumeration.variants; let epsilon: f64 = 1e-9; let one_minus_epsilon = 1.0 - epsilon; - let mut function_body = quote! { + function_body = quote! { let mut variant: f64 = rng.gen_range(0.0..#one_minus_epsilon); }; @@ -30,66 +107,17 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream { // parse enum fields for variant in variants.iter() { - let probability: f64 = { - let attr = variant.attrs.iter().find(|a| { - let path = &a.path; - if let Some(ident) = path.get_ident() { - ident == "prob" - } else { - false - } - }); - let Some(attr) = attr else { - panic!("Variant {} has no probability", variant.ident) - }; - let tokens: TokenStream = attr.tokens.clone().into(); - let tokens: Vec<TokenTree> = tokens.into_iter().collect(); - if tokens.len() != 2 { - panic!("Expected prob = <floating-point number>"); - } - match &tokens[0] { - Punct(equals) if equals.as_char() == '=' => {} - _ => panic!("Expected = after prob attribute"), - }; - - let Literal(literal) = &tokens[1] else { - panic!("Bad number for prob attribute."); - }; - literal.to_string().parse().expect("Bad number for prob attribute") + let probability: Option<f64> = parse_attribute_value(&variant.attrs, "prob"); + let Some(probability) = probability else { + panic!("Variant {} has no probability", variant.ident) }; let name = &variant.ident; - - - let mut variant_arguments = quote! {}; - for field in variant.fields.iter() { - if let Some(name) = &field.ident { - variant_arguments.extend(quote! {#name: }); - } - let ty = &field.ty; - variant_arguments.extend(quote! { <#ty as GenRandom>::gen_random(rng), }); - } - - // surround the arguments with either () or {} brackets - let constructor_group = match variant.fields { - syn::Fields::Named(_) => { - Some(proc_macro2::Group::new( - proc_macro2::Delimiter::Brace, - variant_arguments - )) - }, - syn::Fields::Unnamed(_) => { - Some(proc_macro2::Group::new( - proc_macro2::Delimiter::Parenthesis, - variant_arguments - )) - }, - syn::Fields::Unit => None, - }; + let field_values = generate_fields(&variant.fields); function_body.extend(quote! { variant -= #probability; - if variant <= 0.0 { return Self::#name #constructor_group; } + if variant <= 0.0 { return Self::#name #field_values; } }); test_variant -= probability; @@ -104,17 +132,25 @@ fn impl_gen_random(ast: &syn::DeriveInput) -> TokenStream { panic!("RNG returned value outside of range.") }); - let gen = quote! { - impl GenRandom for #name { - fn gen_random(rng: &mut impl rand::Rng) -> Self { - #function_body - } - } + }, + syn::Data::Struct(structure) => { + let field_values = generate_fields(&structure.fields); + function_body = quote! { + Self #field_values }; - //println!("{gen}"); - gen.into() }, - _ => panic!("derive(GenRandom) can currently only be applied to enums."), - } + syn::Data::Union(_) => panic!("derive(GenRandom) cannot be applied to unions."), + }; + + let gen = quote! { + impl GenRandom for #name { + fn gen_random(rng: &mut impl rand::Rng) -> Self { + #function_body + } + } + }; + + //println!("{gen}"); + gen.into() } diff --git a/gen_random_test/src/lib.rs b/gen_random_test/src/lib.rs index 00f5026..bea6190 100644 --- a/gen_random_test/src/lib.rs +++ b/gen_random_test/src/lib.rs @@ -33,6 +33,16 @@ mod tests { #[prob = 0.9] Cons(f32, Box<LinkedList>) } + + #[derive(GenRandom, Debug)] + struct ScaleBias { + #[bias = 1.0] + #[scale = 10.0] + a: f32, + #[bias = 2.0] + #[scale = 0.0] + b: f32, + } #[test] fn basic() { @@ -51,4 +61,18 @@ mod tests { let ll = LinkedList::gen_thread_random(); println!("{ll:?}"); } + + #[test] + fn scale_bias() { + let sb: Vec<ScaleBias> = gen_thread_random_vec(10); + println!("{sb:?}"); + for x in sb.iter() { + if x.a < 1.0 || x.a > 11.0 { + panic!("a field should be between 1 and 11; got {}", x.a); + } + if x.b != 2.0 { + panic!("b field should be exactly 2; got {}", x.b); + } + } + } } diff --git a/src/main.rs b/src/main.rs index ab2aa40..d1afbce 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ /* @TODO: - use 0..(sum of probs) for variant -- scale and bias - fullscreen key - mathematical analysis - options for: @@ -10,19 +9,19 @@ - AA quality - # iterations, distance cutoff - documentation +- GenRandom integers (+ gen_random_scale_bias) */ -extern crate nalgebra; extern crate gen_random; +extern crate nalgebra; pub mod sdf; mod sdl; pub mod win; +use gen_random::GenRandom; use nalgebra::{Matrix3, Matrix4, Rotation3, Vector3}; use std::time::Instant; -use gen_random::GenRandom; - type Vec3 = Vector3<f32>; type Mat3 = Matrix3<f32>; @@ -70,7 +69,7 @@ fn try_main() -> Result<(), String> { use sdf::{Constant, R3ToR, R3ToR3, RToR}; let _test = Constant::gen_thread_random(); println!("{_test:?}"); - + let funciton = R3ToR::compose( R3ToR3::InfiniteMirrors(Constant::from(2.0)), R3ToR::sphere_f32(0.2), @@ -1,10 +1,10 @@ #![allow(dead_code)] // @TODO @TEMPORARY -extern crate rand; extern crate gen_random_proc_macro; +extern crate rand; -use std::fmt::{self, Display, Formatter, Write}; -use gen_random_proc_macro::GenRandom; use gen_random::GenRandom; +use gen_random_proc_macro::GenRandom; +use std::fmt::{self, Display, Formatter, Write}; // we're only writing numbers and strings so write! should never fail. macro_rules! write_str { |