Skip to main content

miden_utils_core_derive/
lib.rs

1//! Proc macro to derive enum dispatch trait implementations for Miden core utilities
2//!
3//! This crate provides proc macros for enums that need to dispatch trait method calls to their
4//! variants:
5//! - `MastNodeExt` derive macro: generates MastNodeExt trait implementations for enums
6//! - `MastForestContributor` derive macro: generates MastForestContributor trait implementations
7//!   for enums
8//!
9//! This crate provides enum dispatch functionality with:
10//! - Zero-cost enum dispatch without external dependencies
11//! - Better control over generated code
12//! - Support for complex trait patterns
13//! - Cleaner, more maintainable implementations
14//!
15//! # Example
16//!
17//! ```rust,ignore
18//! use miden_utils_core_derive::MastForestContributor;
19//!
20//! #[derive(MastForestContributor)]
21//! pub enum MyEnum {
22//!     Variant1(Type1),
23//!     Variant2(Type2),
24//! }
25//! ```
26
27extern crate proc_macro;
28
29use proc_macro::TokenStream;
30use proc_macro2::Span;
31use quote::quote;
32use syn::{
33    Attribute, Data, DeriveInput, Fields, Ident, Lit, Meta, NestedMeta, Type, Variant,
34    parse_macro_input,
35};
36
37/// Derive the MastNodeExt trait for an enum.
38///
39/// This macro automatically generates implementations for all methods in the MastNodeExt trait..
40///
41/// # Attributes
42///
43/// - `#[mast_node_ext(builder = "BuilderType")]` - Specifies the builder type to use
44///
45/// # Example
46///
47/// ```rust,ignore
48/// use miden_utils_core_derive::MastNodeExt;
49///
50/// #[derive(MastNodeExt)]
51/// #[mast_node_ext(builder = "MyBuilder")]
52/// pub enum MyEnum {
53///     Variant1(Type1),
54///     Variant2(Type2),
55///     // ... other variants
56/// }
57/// ```
58#[proc_macro_derive(MastNodeExt, attributes(mast_node_ext))]
59pub fn derive_mast_node_ext(input: TokenStream) -> TokenStream {
60    let input = parse_macro_input!(input as DeriveInput);
61
62    let enum_name = &input.ident;
63    let generics = &input.generics;
64
65    // Parse the data to ensure it's an enum
66    let enum_data = match &input.data {
67        Data::Enum(data) => data,
68        _ => panic!("MastNodeExt can only be derived for enums"),
69    };
70
71    // Extract the builder type from the attribute
72    let builder_type = extract_builder_type(&input.attrs);
73
74    // Extract variant information
75    let variants: Vec<_> = enum_data.variants.iter().collect();
76    let variant_names: Vec<_> = variants.iter().map(|v| &v.ident).collect();
77    let variant_fields: Vec<_> = variants.iter().map(|v| extract_single_field(v)).collect();
78
79    // Get the list of methods to generate implementations for
80    let methods = get_mast_node_ext_methods();
81
82    let method_impls: Vec<proc_macro2::TokenStream> = methods
83        .iter()
84        .map(|method_name| {
85            generate_method_impl_for_trait_method(
86                enum_name,
87                method_name,
88                &variant_names,
89                &variant_fields,
90                &builder_type,
91            )
92        })
93        .collect();
94
95    // Build the trait implementation
96    let trait_impl = quote! {
97        impl #generics MastNodeExt for #enum_name #generics {
98            type Builder = #builder_type;
99
100            #(#method_impls)*
101        }
102    };
103
104    TokenStream::from(trait_impl)
105}
106
/// Names of the `MastNodeExt` methods the derive generates, in emission order.
fn get_mast_node_ext_methods() -> Vec<&'static str> {
    const METHODS: [&str; 11] = [
        "digest",
        "before_enter",
        "after_exit",
        "to_display",
        "to_pretty_print",
        "has_children",
        "append_children_to",
        "for_each_child",
        "domain",
        "verify_node_in_forest",
        "to_builder",
    ];

    METHODS.to_vec()
}
122
123/// Generate method implementation with a more compact approach
124fn generate_method_impl_for_trait_method(
125    enum_name: &Ident,
126    method_name: &str,
127    variant_names: &[&Ident],
128    variant_fields: &[Ident],
129    builder_type: &Type,
130) -> proc_macro2::TokenStream {
131    match method_name {
132        "digest" => quote! {
133            fn digest(&self) -> miden_crypto::Word {
134                match self {
135                    #(#enum_name::#variant_names(field) => field.digest()),*
136                }
137            }
138        },
139        "before_enter" => quote! {
140            fn before_enter<'a>(&'a self, forest: &'a crate::mast::MastForest) -> &'a [crate::mast::DecoratorId] {
141                match self {
142                    #(#enum_name::#variant_names(field) => field.before_enter(forest)),*
143                }
144            }
145        },
146        "after_exit" => quote! {
147            fn after_exit<'a>(&'a self, forest: &'a crate::mast::MastForest) -> &'a [crate::mast::DecoratorId] {
148                match self {
149                    #(#enum_name::#variant_names(field) => field.after_exit(forest)),*
150                }
151            }
152        },
153        "to_display" => quote! {
154            fn to_display<'a>(&'a self, mast_forest: &'a crate::mast::MastForest) -> Box<dyn core::fmt::Display + 'a> {
155                match self {
156                    #(#enum_name::#variant_names(field) => Box::new(field.to_display(mast_forest))),*
157                }
158            }
159        },
160        "to_pretty_print" => quote! {
161            fn to_pretty_print<'a>(&'a self, mast_forest: &'a crate::mast::MastForest) -> Box<dyn miden_formatting::prettier::PrettyPrint + 'a> {
162                match self {
163                    #(#enum_name::#variant_names(field) => Box::new(field.to_pretty_print(mast_forest))),*
164                }
165            }
166        },
167        "has_children" => quote! {
168            fn has_children(&self) -> bool {
169                match self {
170                    #(#enum_name::#variant_names(field) => field.has_children()),*
171                }
172            }
173        },
174        "append_children_to" => quote! {
175            fn append_children_to(&self, target: &mut alloc::vec::Vec<crate::mast::MastNodeId>) {
176                match self {
177                    #(#enum_name::#variant_names(field) => field.append_children_to(target)),*
178                }
179            }
180        },
181        "for_each_child" => quote! {
182            fn for_each_child<F>(&self, mut f: F) where F: FnMut(crate::mast::MastNodeId) {
183                match self {
184                    #(#enum_name::#variant_names(field) => field.for_each_child(f)),*
185                }
186            }
187        },
188        "domain" => quote! {
189            fn domain(&self) -> miden_crypto::Felt {
190                match self {
191                    #(#enum_name::#variant_names(field) => field.domain()),*
192                }
193            }
194        },
195        "verify_node_in_forest" => quote! {
196            #[cfg(debug_assertions)]
197            fn verify_node_in_forest(&self, forest: &crate::mast::MastForest) {
198                match self {
199                    #(#enum_name::#variant_names(field) => field.verify_node_in_forest(forest)),*
200                }
201            }
202        },
203        "to_builder" => {
204            generate_to_builder_method(enum_name, variant_names, variant_fields, builder_type)
205        },
206        _ => panic!("Unknown method: {}", method_name),
207    }
208}
209
210/// Generate to_builder method implementation
211///
212/// Contains variant name mappings for compatibility with builder types.
213fn generate_to_builder_method(
214    enum_name: &Ident,
215    variant_names: &[&Ident],
216    variant_fields: &[Ident],
217    builder_type: &Type,
218) -> proc_macro2::TokenStream {
219    let match_arms = variant_names.iter().zip(variant_fields.iter()).map(|(variant, field)| {
220        // Convert variant name to builder variant name
221        let builder_variant_name = match variant.to_string().as_str() {
222            "Block" => Ident::new("BasicBlock", Span::call_site()),
223            _ => (*variant).clone(), // Use the same name for other variants
224        };
225
226        quote! {
227            #enum_name::#variant(#field) => #builder_type::#builder_variant_name(#field.to_builder(forest))
228        }
229    });
230
231    quote! {
232        fn to_builder(self, forest: &crate::mast::MastForest) -> Self::Builder {
233            match self {
234                #(#match_arms),*
235            }
236        }
237    }
238}
239
240/// Derive trait implementations for enums that dispatch to variant trait implementations.
241///
242/// This macro generates trait implementations that forward method calls to the corresponding
243/// variant's trait implementation, similar to the `enum_dispatch` crate but without the
244/// external dependency.
245///
246/// # Example
247///
248/// ```rust,ignore
249/// use miden_utils_core_derive::MastForestContributor;
250///
251/// #[utils_core_derive(MyTrait)]
252/// #[derive(MastForestContributor)]
253/// pub enum MyEnum {
254///     Variant1(Type1),
255///     Variant2(Type2),
256/// }
257/// ```
258#[proc_macro_derive(MastForestContributor)]
259pub fn derive_mast_forest_contributor(input: TokenStream) -> TokenStream {
260    let input = parse_macro_input!(input as DeriveInput);
261
262    let enum_name = &input.ident;
263    let generics = &input.generics;
264
265    // Parse the data to ensure it's an enum
266    let enum_data = match &input.data {
267        Data::Enum(data) => data,
268        _ => panic!("EnumThispatch can only be derived for enums"),
269    };
270
271    // Extract variant information
272    let variants: Vec<_> = enum_data.variants.iter().collect();
273    let variant_names: Vec<_> = variants.iter().map(|v| &v.ident).collect();
274    let variant_fields: Vec<_> = variants.iter().map(|v| extract_single_field(v)).collect();
275
276    // Generate trait implementation by reading the trait definition
277    let trait_impl =
278        generate_mast_forest_contributor_impl(enum_name, generics, &variant_names, &variant_fields);
279
280    TokenStream::from(trait_impl)
281}
282
283/// Generate MastForestContributor trait implementation for enum dispatch
284fn generate_mast_forest_contributor_impl(
285    enum_name: &Ident,
286    generics: &syn::Generics,
287    variant_names: &[&Ident],
288    variant_fields: &[Ident],
289) -> proc_macro2::TokenStream {
290    // For now, let's generate a simple implementation to test the macro
291    let add_to_forest_arms =
292        variant_names.iter().zip(variant_fields.iter()).map(|(variant, field)| {
293            quote! {
294                #enum_name::#variant(#field) => #field.add_to_forest(forest)
295            }
296        });
297
298    quote! {
299        impl #generics crate::mast::MastForestContributor for #enum_name #generics {
300            fn add_to_forest(self, forest: &mut crate::mast::MastForest) -> Result<crate::mast::MastNodeId, crate::mast::MastForestError> {
301                match self {
302                    #(#add_to_forest_arms),*
303                }
304            }
305
306            fn fingerprint_for_node(
307                &self,
308                forest: &crate::mast::MastForest,
309                hash_by_node_id: &impl crate::utils::LookupByIdx<crate::mast::MastNodeId, crate::mast::MastNodeFingerprint>,
310            ) -> Result<crate::mast::MastNodeFingerprint, crate::mast::MastForestError> {
311                match self {
312                    #(#enum_name::#variant_names(field) => field.fingerprint_for_node(forest, hash_by_node_id)),*
313                }
314            }
315
316            fn remap_children(self, remapping: &impl crate::utils::LookupByIdx<crate::mast::MastNodeId, crate::mast::MastNodeId>) -> Self {
317                match self {
318                    #(#enum_name::#variant_names(field) => #enum_name::#variant_names(field.remap_children(remapping))),*
319                }
320            }
321
322            fn with_before_enter(self, decorators: impl Into<alloc::vec::Vec<crate::mast::DecoratorId>>) -> Self {
323                match self {
324                    #(#enum_name::#variant_names(field) => #enum_name::#variant_names(field.with_before_enter(decorators))),*
325                }
326            }
327
328            fn with_after_exit(self, decorators: impl Into<alloc::vec::Vec<crate::mast::DecoratorId>>) -> Self {
329                match self {
330                    #(#enum_name::#variant_names(field) => #enum_name::#variant_names(field.with_after_exit(decorators))),*
331                }
332            }
333
334            fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = crate::mast::DecoratorId>) {
335                match self {
336                    #(#enum_name::#variant_names(field) => field.append_before_enter(decorators)),*
337                }
338            }
339
340            fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = crate::mast::DecoratorId>) {
341                match self {
342                    #(#enum_name::#variant_names(field) => field.append_after_exit(decorators)),*
343                }
344            }
345
346            fn with_digest(self, digest: crate::Word) -> Self {
347                match self {
348                    #(#enum_name::#variant_names(field) => #enum_name::#variant_names(field.with_digest(digest))),*
349                }
350            }
351        }
352    }
353}
354
355/// Extract the builder type from the #[mast_node_ext(builder = "...")] attribute
356fn extract_builder_type(attrs: &[Attribute]) -> Type {
357    for attr in attrs {
358        if attr.path.is_ident("mast_node_ext") {
359            let meta = attr.parse_meta().expect("Failed to parse mast_node_ext attribute");
360
361            if let Meta::List(meta_list) = meta {
362                for nested in meta_list.nested {
363                    if let NestedMeta::Meta(Meta::NameValue(name_value)) = nested
364                        && name_value.path.is_ident("builder")
365                        && let Lit::Str(lit_str) = &name_value.lit
366                    {
367                        let type_str = lit_str.value();
368                        return syn::parse_str::<Type>(&type_str)
369                            .expect("Invalid builder type specification");
370                    }
371                }
372            }
373        }
374    }
375
376    panic!("Missing required attribute: #[mast_node_ext(builder = \"...\")]");
377}
378
379/// Extract the single field from a variant (e.g., BasicBlockNode from Block(BasicBlockNode))
380fn extract_single_field(variant: &Variant) -> Ident {
381    match &variant.fields {
382        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
383            // For unnamed fields, we need to create a variable name
384            // We'll use "node" as the field name in the generated code
385            Ident::new("node", Span::call_site())
386        },
387        _ => panic!(
388            "Each variant must have exactly one unnamed field, but {:?} does not",
389            variant.ident
390        ),
391    }
392}