1extern crate proc_macro;
28
29use proc_macro::TokenStream;
30use proc_macro2::Span;
31use quote::quote;
32use syn::{
33 Attribute, Data, DeriveInput, Fields, Ident, Lit, Meta, NestedMeta, Type, Variant,
34 parse_macro_input,
35};
36
37#[proc_macro_derive(MastNodeExt, attributes(mast_node_ext))]
59pub fn derive_mast_node_ext(input: TokenStream) -> TokenStream {
60 let input = parse_macro_input!(input as DeriveInput);
61
62 let enum_name = &input.ident;
63 let generics = &input.generics;
64
65 let enum_data = match &input.data {
67 Data::Enum(data) => data,
68 _ => panic!("MastNodeExt can only be derived for enums"),
69 };
70
71 let builder_type = extract_builder_type(&input.attrs);
73
74 let variants: Vec<_> = enum_data.variants.iter().collect();
76 let variant_names: Vec<_> = variants.iter().map(|v| &v.ident).collect();
77 let variant_fields: Vec<_> = variants.iter().map(|v| extract_single_field(v)).collect();
78
79 let methods = get_mast_node_ext_methods();
81
82 let method_impls: Vec<proc_macro2::TokenStream> = methods
83 .iter()
84 .map(|method_name| {
85 generate_method_impl_for_trait_method(
86 enum_name,
87 method_name,
88 &variant_names,
89 &variant_fields,
90 &builder_type,
91 )
92 })
93 .collect();
94
95 let trait_impl = quote! {
97 impl #generics MastNodeExt for #enum_name #generics {
98 type Builder = #builder_type;
99
100 #(#method_impls)*
101 }
102 };
103
104 TokenStream::from(trait_impl)
105}
106
/// Lists, in emission order, the `MastNodeExt` methods the derive generates.
///
/// Every entry must have a corresponding arm in `generate_method_impl_for_trait_method`,
/// which panics on names it does not recognise.
fn get_mast_node_ext_methods() -> Vec<&'static str> {
    const METHOD_NAMES: [&str; 11] = [
        "digest",
        "before_enter",
        "after_exit",
        "to_display",
        "to_pretty_print",
        "has_children",
        "append_children_to",
        "for_each_child",
        "domain",
        "verify_node_in_forest",
        "to_builder",
    ];
    Vec::from(METHOD_NAMES)
}
122
123fn generate_method_impl_for_trait_method(
125 enum_name: &Ident,
126 method_name: &str,
127 variant_names: &[&Ident],
128 variant_fields: &[Ident],
129 builder_type: &Type,
130) -> proc_macro2::TokenStream {
131 match method_name {
132 "digest" => quote! {
133 fn digest(&self) -> miden_crypto::Word {
134 match self {
135 #(#enum_name::#variant_names(field) => field.digest()),*
136 }
137 }
138 },
139 "before_enter" => quote! {
140 fn before_enter<'a>(&'a self, forest: &'a crate::mast::MastForest) -> &'a [crate::mast::DecoratorId] {
141 match self {
142 #(#enum_name::#variant_names(field) => field.before_enter(forest)),*
143 }
144 }
145 },
146 "after_exit" => quote! {
147 fn after_exit<'a>(&'a self, forest: &'a crate::mast::MastForest) -> &'a [crate::mast::DecoratorId] {
148 match self {
149 #(#enum_name::#variant_names(field) => field.after_exit(forest)),*
150 }
151 }
152 },
153 "to_display" => quote! {
154 fn to_display<'a>(&'a self, mast_forest: &'a crate::mast::MastForest) -> Box<dyn core::fmt::Display + 'a> {
155 match self {
156 #(#enum_name::#variant_names(field) => Box::new(field.to_display(mast_forest))),*
157 }
158 }
159 },
160 "to_pretty_print" => quote! {
161 fn to_pretty_print<'a>(&'a self, mast_forest: &'a crate::mast::MastForest) -> Box<dyn miden_formatting::prettier::PrettyPrint + 'a> {
162 match self {
163 #(#enum_name::#variant_names(field) => Box::new(field.to_pretty_print(mast_forest))),*
164 }
165 }
166 },
167 "has_children" => quote! {
168 fn has_children(&self) -> bool {
169 match self {
170 #(#enum_name::#variant_names(field) => field.has_children()),*
171 }
172 }
173 },
174 "append_children_to" => quote! {
175 fn append_children_to(&self, target: &mut alloc::vec::Vec<crate::mast::MastNodeId>) {
176 match self {
177 #(#enum_name::#variant_names(field) => field.append_children_to(target)),*
178 }
179 }
180 },
181 "for_each_child" => quote! {
182 fn for_each_child<F>(&self, mut f: F) where F: FnMut(crate::mast::MastNodeId) {
183 match self {
184 #(#enum_name::#variant_names(field) => field.for_each_child(f)),*
185 }
186 }
187 },
188 "domain" => quote! {
189 fn domain(&self) -> miden_crypto::Felt {
190 match self {
191 #(#enum_name::#variant_names(field) => field.domain()),*
192 }
193 }
194 },
195 "verify_node_in_forest" => quote! {
196 #[cfg(debug_assertions)]
197 fn verify_node_in_forest(&self, forest: &crate::mast::MastForest) {
198 match self {
199 #(#enum_name::#variant_names(field) => field.verify_node_in_forest(forest)),*
200 }
201 }
202 },
203 "to_builder" => {
204 generate_to_builder_method(enum_name, variant_names, variant_fields, builder_type)
205 },
206 _ => panic!("Unknown method: {}", method_name),
207 }
208}
209
210fn generate_to_builder_method(
214 enum_name: &Ident,
215 variant_names: &[&Ident],
216 variant_fields: &[Ident],
217 builder_type: &Type,
218) -> proc_macro2::TokenStream {
219 let match_arms = variant_names.iter().zip(variant_fields.iter()).map(|(variant, field)| {
220 let builder_variant_name = match variant.to_string().as_str() {
222 "Block" => Ident::new("BasicBlock", Span::call_site()),
223 _ => (*variant).clone(), };
225
226 quote! {
227 #enum_name::#variant(#field) => #builder_type::#builder_variant_name(#field.to_builder(forest))
228 }
229 });
230
231 quote! {
232 fn to_builder(self, forest: &crate::mast::MastForest) -> Self::Builder {
233 match self {
234 #(#match_arms),*
235 }
236 }
237 }
238}
239
240#[proc_macro_derive(MastForestContributor)]
259pub fn derive_mast_forest_contributor(input: TokenStream) -> TokenStream {
260 let input = parse_macro_input!(input as DeriveInput);
261
262 let enum_name = &input.ident;
263 let generics = &input.generics;
264
265 let enum_data = match &input.data {
267 Data::Enum(data) => data,
268 _ => panic!("EnumThispatch can only be derived for enums"),
269 };
270
271 let variants: Vec<_> = enum_data.variants.iter().collect();
273 let variant_names: Vec<_> = variants.iter().map(|v| &v.ident).collect();
274 let variant_fields: Vec<_> = variants.iter().map(|v| extract_single_field(v)).collect();
275
276 let trait_impl =
278 generate_mast_forest_contributor_impl(enum_name, generics, &variant_names, &variant_fields);
279
280 TokenStream::from(trait_impl)
281}
282
283fn generate_mast_forest_contributor_impl(
285 enum_name: &Ident,
286 generics: &syn::Generics,
287 variant_names: &[&Ident],
288 variant_fields: &[Ident],
289) -> proc_macro2::TokenStream {
290 let add_to_forest_arms =
292 variant_names.iter().zip(variant_fields.iter()).map(|(variant, field)| {
293 quote! {
294 #enum_name::#variant(#field) => #field.add_to_forest(forest)
295 }
296 });
297
298 quote! {
299 impl #generics crate::mast::MastForestContributor for #enum_name #generics {
300 fn add_to_forest(self, forest: &mut crate::mast::MastForest) -> Result<crate::mast::MastNodeId, crate::mast::MastForestError> {
301 match self {
302 #(#add_to_forest_arms),*
303 }
304 }
305
306 fn fingerprint_for_node(
307 &self,
308 forest: &crate::mast::MastForest,
309 hash_by_node_id: &impl crate::utils::LookupByIdx<crate::mast::MastNodeId, crate::mast::MastNodeFingerprint>,
310 ) -> Result<crate::mast::MastNodeFingerprint, crate::mast::MastForestError> {
311 match self {
312 #(#enum_name::#variant_names(field) => field.fingerprint_for_node(forest, hash_by_node_id)),*
313 }
314 }
315
316 fn remap_children(self, remapping: &impl crate::utils::LookupByIdx<crate::mast::MastNodeId, crate::mast::MastNodeId>) -> Self {
317 match self {
318 #(#enum_name::#variant_names(field) => #enum_name::#variant_names(field.remap_children(remapping))),*
319 }
320 }
321
322 fn with_before_enter(self, decorators: impl Into<alloc::vec::Vec<crate::mast::DecoratorId>>) -> Self {
323 match self {
324 #(#enum_name::#variant_names(field) => #enum_name::#variant_names(field.with_before_enter(decorators))),*
325 }
326 }
327
328 fn with_after_exit(self, decorators: impl Into<alloc::vec::Vec<crate::mast::DecoratorId>>) -> Self {
329 match self {
330 #(#enum_name::#variant_names(field) => #enum_name::#variant_names(field.with_after_exit(decorators))),*
331 }
332 }
333
334 fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = crate::mast::DecoratorId>) {
335 match self {
336 #(#enum_name::#variant_names(field) => field.append_before_enter(decorators)),*
337 }
338 }
339
340 fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = crate::mast::DecoratorId>) {
341 match self {
342 #(#enum_name::#variant_names(field) => field.append_after_exit(decorators)),*
343 }
344 }
345
346 fn with_digest(self, digest: crate::Word) -> Self {
347 match self {
348 #(#enum_name::#variant_names(field) => #enum_name::#variant_names(field.with_digest(digest))),*
349 }
350 }
351 }
352 }
353}
354
355fn extract_builder_type(attrs: &[Attribute]) -> Type {
357 for attr in attrs {
358 if attr.path.is_ident("mast_node_ext") {
359 let meta = attr.parse_meta().expect("Failed to parse mast_node_ext attribute");
360
361 if let Meta::List(meta_list) = meta {
362 for nested in meta_list.nested {
363 if let NestedMeta::Meta(Meta::NameValue(name_value)) = nested
364 && name_value.path.is_ident("builder")
365 && let Lit::Str(lit_str) = &name_value.lit
366 {
367 let type_str = lit_str.value();
368 return syn::parse_str::<Type>(&type_str)
369 .expect("Invalid builder type specification");
370 }
371 }
372 }
373 }
374 }
375
376 panic!("Missing required attribute: #[mast_node_ext(builder = \"...\")]");
377}
378
379fn extract_single_field(variant: &Variant) -> Ident {
381 match &variant.fields {
382 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
383 Ident::new("node", Span::call_site())
386 },
387 _ => panic!(
388 "Each variant must have exactly one unnamed field, but {:?} does not",
389 variant.ident
390 ),
391 }
392}