Skip to main content

miden_ace_codegen/dag/
builder.rs

1use std::collections::HashMap;
2
3use miden_crypto::field::Field;
4
5use super::ir::{DagId, DagSnapshot, NodeId, NodeKind};
6use crate::layout::InputKey;
7
/// A hash-consed DAG builder.
///
/// The builder de-duplicates identical subexpressions to keep the circuit
/// compact and deterministic.
#[derive(Debug)]
pub struct DagBuilder<EF> {
    /// DAG id stamped on every `NodeId` this builder hands out; operands passed back
    /// in are checked against it.
    dag_id: DagId,
    /// Nodes in insertion order; the index of a node's `NodeId` is its position here.
    nodes: Vec<NodeKind<EF>>,
    /// Dedup table mapping each stored node to the id it was stored under.
    cache: HashMap<NodeKind<EF>, NodeId>,
    /// Source of nodes imported via `from_nodes` / `from_snapshot`, used to accept that
    /// source DAG's ids for the imported prefix. `None` for `new` / `from_dag`, and for
    /// `from_nodes` when no provenance could be inferred.
    imported_dag: Option<ImportedDag>,
}
19
20impl<EF> DagBuilder<EF>
21where
22    EF: Field,
23{
24    /// Create an empty, hash-consed DAG builder.
25    pub fn new() -> Self {
26        Self {
27            dag_id: DagId::fresh(),
28            nodes: Vec::new(),
29            cache: HashMap::new(),
30            imported_dag: None,
31        }
32    }
33
34    /// Resume building from existing nodes using the published 0.23.0 API shape.
35    ///
36    /// Imported node ids are rebased onto the new builder, and ids from the source DAG
37    /// are accepted only when that provenance is encoded in the node graph itself.
38    pub fn from_nodes(nodes: Vec<NodeKind<EF>>) -> Self {
39        let imported_dag = infer_dag_id(&nodes)
40            .map(|source_dag_id| ImportedDag { source_dag_id, imported_len: nodes.len() });
41        let dag_id = DagId::fresh();
42        let nodes = rebase_nodes(nodes, dag_id);
43
44        Self::from_existing_nodes(dag_id, nodes, imported_dag)
45    }
46
47    /// Resume building from an exported snapshot.
48    ///
49    /// This preserves the original DAG id even when the imported nodes are all leaves.
50    pub fn from_snapshot(snapshot: DagSnapshot<EF>) -> Self {
51        let (source_dag_id, nodes, _) = snapshot.into_parts();
52        let dag_id = DagId::fresh();
53        let imported_dag = Some(ImportedDag { source_dag_id, imported_len: nodes.len() });
54        let nodes = rebase_nodes(nodes, dag_id);
55
56        Self::from_existing_nodes(dag_id, nodes, imported_dag)
57    }
58
59    /// Resume building from an existing DAG.
60    ///
61    /// Rebuilds the deduplication cache so that subsequent operations reuse
62    /// existing subexpressions.
63    pub fn from_dag(dag: super::AceDag<EF>) -> Self {
64        let dag_id = dag.dag_id();
65        Self::from_existing_nodes(dag_id, dag.into_nodes(), None)
66    }
67
68    fn from_existing_nodes(
69        dag_id: DagId,
70        nodes: Vec<NodeKind<EF>>,
71        imported_dag: Option<ImportedDag>,
72    ) -> Self {
73        let cache = nodes
74            .iter()
75            .enumerate()
76            .map(|(i, n)| (n.clone(), NodeId::in_dag(i, dag_id)))
77            .collect();
78        Self { dag_id, nodes, cache, imported_dag }
79    }
80
81    /// Consume the builder and return its node list.
82    pub fn into_nodes(self) -> Vec<NodeKind<EF>> {
83        self.nodes
84    }
85
86    /// Consume the builder and return a DAG with the provided root.
87    pub fn build(self, root: NodeId) -> super::AceDag<EF> {
88        let root = self.resolve_id(root, "DAG root must refer to a node built by this DagBuilder");
89
90        super::AceDag::from_parts(self.dag_id, self.nodes, root)
91    }
92
93    /// Add an input node.
94    pub fn input(&mut self, key: InputKey) -> NodeId {
95        self.intern(NodeKind::Input(key))
96    }
97
98    /// Add a constant node.
99    pub fn constant(&mut self, value: EF) -> NodeId {
100        self.intern(NodeKind::Constant(value))
101    }
102
103    /// Add an addition node (with constant folding).
104    pub fn add(&mut self, a: NodeId, b: NodeId) -> NodeId {
105        let a = self.resolve_node(a);
106        let b = self.resolve_node(b);
107        if let (Some(x), Some(y)) = (self.const_value(a), self.const_value(b)) {
108            return self.constant(x + y);
109        }
110        if self.is_zero(a) {
111            return b;
112        }
113        if self.is_zero(b) {
114            return a;
115        }
116        let (l, r) = if a <= b { (a, b) } else { (b, a) };
117        self.intern(NodeKind::Add(l, r))
118    }
119
120    /// Add a subtraction node (with constant folding).
121    pub fn sub(&mut self, a: NodeId, b: NodeId) -> NodeId {
122        let a = self.resolve_node(a);
123        let b = self.resolve_node(b);
124        if let (Some(x), Some(y)) = (self.const_value(a), self.const_value(b)) {
125            return self.constant(x - y);
126        }
127        if self.is_zero(b) {
128            return a;
129        }
130        self.intern(NodeKind::Sub(a, b))
131    }
132
133    /// Add a multiplication node (with constant folding).
134    pub fn mul(&mut self, a: NodeId, b: NodeId) -> NodeId {
135        let a = self.resolve_node(a);
136        let b = self.resolve_node(b);
137        if let (Some(x), Some(y)) = (self.const_value(a), self.const_value(b)) {
138            return self.constant(x * y);
139        }
140        if self.is_zero(a) || self.is_zero(b) {
141            return self.constant(EF::ZERO);
142        }
143        if self.is_one(a) {
144            return b;
145        }
146        if self.is_one(b) {
147            return a;
148        }
149        let (l, r) = if a <= b { (a, b) } else { (b, a) };
150        self.intern(NodeKind::Mul(l, r))
151    }
152
153    /// Add a negation node (with constant folding).
154    pub fn neg(&mut self, a: NodeId) -> NodeId {
155        let a = self.resolve_node(a);
156        if let Some(x) = self.const_value(a) {
157            return self.constant(-x);
158        }
159        self.intern(NodeKind::Neg(a))
160    }
161
162    fn const_value(&self, id: NodeId) -> Option<EF> {
163        match self.nodes.get(id.index())? {
164            NodeKind::Constant(v) => Some(*v),
165            _ => None,
166        }
167    }
168
169    fn is_zero(&self, id: NodeId) -> bool {
170        self.const_value(id).is_some_and(|v| v == EF::ZERO)
171    }
172
173    fn is_one(&self, id: NodeId) -> bool {
174        self.const_value(id).is_some_and(|v| v == EF::ONE)
175    }
176
177    fn resolve_node(&self, id: NodeId) -> NodeId {
178        self.resolve_id(id, "DAG node must come from this DagBuilder")
179    }
180
181    fn intern(&mut self, node: NodeKind<EF>) -> NodeId {
182        if let Some(id) = self.cache.get(&node) {
183            return *id;
184        }
185        let id = NodeId::in_dag(self.nodes.len(), self.dag_id);
186        self.nodes.push(node.clone());
187        self.cache.insert(node, id);
188        id
189    }
190
191    fn resolve_id(&self, id: NodeId, message: &str) -> NodeId {
192        assert!(id.index() < self.nodes.len(), "{message}");
193
194        if id.dag_id == self.dag_id {
195            return id;
196        }
197
198        if let Some(imported) = &self.imported_dag
199            && imported.source_dag_id == id.dag_id
200            && id.index() < imported.imported_len
201        {
202            return NodeId::in_dag(id.index(), self.dag_id);
203        }
204
205        panic!("{message}");
206    }
207}
208
209fn infer_dag_id<EF>(nodes: &[NodeKind<EF>]) -> Option<DagId> {
210    nodes.iter().find_map(|node| match node {
211        NodeKind::Add(a, _) | NodeKind::Sub(a, _) | NodeKind::Mul(a, _) | NodeKind::Neg(a) => {
212            Some(a.dag_id)
213        },
214        NodeKind::Input(_) | NodeKind::Constant(_) => None,
215    })
216}
217
218fn rebase_nodes<EF>(nodes: Vec<NodeKind<EF>>, dag_id: DagId) -> Vec<NodeKind<EF>> {
219    nodes
220        .into_iter()
221        .map(|node| match node {
222            NodeKind::Input(key) => NodeKind::Input(key),
223            NodeKind::Constant(value) => NodeKind::Constant(value),
224            NodeKind::Add(a, b) => NodeKind::Add(rebase_node(a, dag_id), rebase_node(b, dag_id)),
225            NodeKind::Sub(a, b) => NodeKind::Sub(rebase_node(a, dag_id), rebase_node(b, dag_id)),
226            NodeKind::Mul(a, b) => NodeKind::Mul(rebase_node(a, dag_id), rebase_node(b, dag_id)),
227            NodeKind::Neg(a) => NodeKind::Neg(rebase_node(a, dag_id)),
228        })
229        .collect()
230}
231
232fn rebase_node(id: NodeId, dag_id: DagId) -> NodeId {
233    NodeId::in_dag(id.index(), dag_id)
234}
235
/// Records which DAG a builder's leading (imported) nodes came from.
#[derive(Debug, Clone)]
struct ImportedDag {
    /// DAG id of the source the nodes were imported from.
    source_dag_id: DagId,
    /// Number of imported nodes. Source ids with an index below this are accepted.
    imported_len: usize,
}
241
242impl<EF> Default for DagBuilder<EF>
243where
244    EF: Field,
245{
246    fn default() -> Self {
247        Self::new()
248    }
249}
250
#[cfg(test)]
mod tests {
    use miden_core::{Felt, field::QuadFelt};

    use super::DagBuilder;
    use crate::layout::InputKey;

    /// Lifts `value` into `QuadFelt` via `Felt::new`.
    fn felt(value: u64) -> QuadFelt {
        QuadFelt::from(Felt::new(value))
    }

    // ---------------------------------------------------------------------------------
    // Hash-consing and algebraic simplification
    // ---------------------------------------------------------------------------------

    #[test]
    fn identical_leaves_are_deduplicated() {
        let mut builder = DagBuilder::<QuadFelt>::new();

        let first_input = builder.input(InputKey::Gamma);
        let second_input = builder.input(InputKey::Gamma);
        assert_eq!(first_input, second_input);

        let first_constant = builder.constant(felt(7));
        let second_constant = builder.constant(felt(7));
        assert_eq!(first_constant, second_constant);
        assert_ne!(first_input, first_constant);
    }

    #[test]
    fn commutative_operations_share_a_node() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let x = builder.input(InputKey::Gamma);
        let c = builder.constant(felt(2));

        let x_plus_c = builder.add(x, c);
        let c_plus_x = builder.add(c, x);
        assert_eq!(x_plus_c, c_plus_x);

        let x_times_c = builder.mul(x, c);
        let c_times_x = builder.mul(c, x);
        assert_eq!(x_times_c, c_times_x);

        // Subtraction is not commutative, so operand order must be preserved.
        let x_minus_c = builder.sub(x, c);
        let c_minus_x = builder.sub(c, x);
        assert_ne!(x_minus_c, c_minus_x);
    }

    #[test]
    fn constant_operands_are_folded() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let two = builder.constant(felt(2));
        let three = builder.constant(felt(3));

        let sum = builder.add(two, three);
        let difference = builder.sub(three, two);
        let product = builder.mul(two, three);
        let negated = builder.neg(two);

        assert_eq!(sum, builder.constant(felt(5)));
        assert_eq!(difference, builder.constant(felt(1)));
        assert_eq!(product, builder.constant(felt(6)));
        assert_eq!(negated, builder.constant(-felt(2)));
    }

    #[test]
    fn algebraic_identities_are_simplified() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let x = builder.input(InputKey::Gamma);
        let zero = builder.constant(felt(0));
        let one = builder.constant(felt(1));

        assert_eq!(builder.add(x, zero), x);
        assert_eq!(builder.add(zero, x), x);
        assert_eq!(builder.sub(x, zero), x);
        assert_eq!(builder.mul(x, one), x);
        assert_eq!(builder.mul(one, x), x);
        assert_eq!(builder.mul(x, zero), zero);
        assert_eq!(builder.mul(zero, x), zero);
    }

    #[test]
    fn from_dag_reuses_existing_subexpressions() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let x = builder.input(InputKey::Gamma);
        let c = builder.constant(felt(2));
        let root = builder.add(x, c);
        let dag = builder.build(root);

        // `from_dag` keeps the DAG id and rebuilds the cache, so re-creating existing
        // nodes must return the ids they already had.
        let mut rebuilt = DagBuilder::from_dag(dag);
        assert_eq!(rebuilt.input(InputKey::Gamma), x);
        assert_eq!(rebuilt.add(c, x), root);
    }

    // ---------------------------------------------------------------------------------
    // Node ownership checks
    // ---------------------------------------------------------------------------------

    #[test]
    #[should_panic(expected = "DAG root must refer to a node built by this DagBuilder")]
    fn build_rejects_same_index_root_from_another_builder() {
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign_root = foreign_builder.constant(felt(1));

        let mut builder = DagBuilder::<QuadFelt>::new();
        builder.constant(felt(1));

        let _ = builder.build(foreign_root);
    }

    #[test]
    #[should_panic(expected = "DAG node must come from this DagBuilder")]
    fn add_rejects_foreign_node() {
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(2));

        let mut builder = DagBuilder::<QuadFelt>::new();
        let local = builder.constant(felt(1));

        let _ = builder.add(local, foreign);
    }

    #[test]
    #[should_panic(expected = "DAG node must come from this DagBuilder")]
    fn sub_rejects_foreign_node() {
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(2));

        let mut builder = DagBuilder::<QuadFelt>::new();
        let local = builder.constant(felt(1));

        let _ = builder.sub(local, foreign);
    }

    #[test]
    #[should_panic(expected = "DAG node must come from this DagBuilder")]
    fn mul_rejects_foreign_node() {
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(2));

        let mut builder = DagBuilder::<QuadFelt>::new();
        let local = builder.constant(felt(1));

        let _ = builder.mul(local, foreign);
    }

    #[test]
    #[should_panic(expected = "DAG node must come from this DagBuilder")]
    fn neg_rejects_foreign_node() {
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(2));

        let mut builder = DagBuilder::<QuadFelt>::new();
        let _ = builder.constant(felt(1));

        let _ = builder.neg(foreign);
    }

    // ---------------------------------------------------------------------------------
    // Resuming from existing DAGs, node lists and snapshots
    // ---------------------------------------------------------------------------------

    #[test]
    fn from_dag_preserves_node_ownership() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let a = builder.constant(felt(1));
        let dag = builder.build(a);
        let root = dag.root();

        let mut rebuilt = DagBuilder::from_dag(dag);
        let b = rebuilt.constant(felt(2));
        let sum = rebuilt.add(root, b);

        let rebuilt_dag = rebuilt.build(sum);
        assert_eq!(rebuilt_dag.root().index(), sum.index());
    }

    #[test]
    fn from_nodes_accepts_published_root_shape() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let a = builder.input(InputKey::Gamma);
        let b = builder.constant(felt(2));
        let root = builder.add(a, b);
        let dag = builder.build(root);

        let mut rebuilt = DagBuilder::from_nodes(dag.nodes.clone());
        let c = rebuilt.constant(felt(3));
        let sum = rebuilt.add(dag.root, c);

        let rebuilt_dag = rebuilt.build(sum);
        assert_eq!(rebuilt_dag.root().index(), sum.index());
    }

    // Exercises `from_snapshot`: a leaf-only `from_nodes` import has no provenance and
    // rejects the old root (see `from_nodes_leaf_only_rejects_foreign_root_before_any_imported_id`).
    #[test]
    fn from_snapshot_accepts_leaf_only_root_shape() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let a = builder.constant(felt(1));
        let dag = builder.build(a);

        let root = dag.root();
        let mut rebuilt = DagBuilder::from_snapshot(dag.into_snapshot());
        let b = rebuilt.constant(felt(2));
        let sum = rebuilt.add(root, b);

        let rebuilt_dag = rebuilt.build(sum);
        assert_eq!(rebuilt_dag.root().index(), sum.index());
    }

    #[test]
    fn from_snapshot_accepts_leaf_only_root_after_source_dag_is_dropped() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let a = builder.constant(felt(1));
        let snapshot = builder.build(a).into_snapshot();
        let root = snapshot.root();

        let mut rebuilt = DagBuilder::from_snapshot(snapshot);
        let b = rebuilt.constant(felt(2));
        let sum = rebuilt.add(root, b);

        let rebuilt_dag = rebuilt.build(sum);
        assert_eq!(rebuilt_dag.root().index(), sum.index());
    }

    #[test]
    #[should_panic(expected = "DAG node must come from this DagBuilder")]
    fn from_nodes_rejects_foreign_node_from_another_builder() {
        let mut source_builder = DagBuilder::<QuadFelt>::new();
        let a = source_builder.input(InputKey::Gamma);
        let b = source_builder.constant(felt(2));
        let root = source_builder.add(a, b);
        let dag = source_builder.build(root);

        let mut rebuilt = DagBuilder::from_nodes(dag.nodes.clone());
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(3));

        let _ = rebuilt.add(dag.root, foreign);
    }

    #[test]
    #[should_panic(expected = "DAG root must refer to a node built by this DagBuilder")]
    fn from_nodes_rejects_foreign_root_from_another_builder() {
        let mut source_builder = DagBuilder::<QuadFelt>::new();
        let a = source_builder.input(InputKey::Gamma);
        let b = source_builder.constant(felt(2));
        let root = source_builder.add(a, b);
        let dag = source_builder.build(root);

        let rebuilt = DagBuilder::from_nodes(dag.nodes.clone());
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(3));

        let _ = rebuilt.build(foreign);
    }

    #[test]
    #[should_panic(expected = "DAG root must refer to a node built by this DagBuilder")]
    fn from_nodes_leaf_only_rejects_foreign_root_before_any_imported_id() {
        let mut source_builder = DagBuilder::<QuadFelt>::new();
        let source = source_builder.constant(felt(1));
        let dag = source_builder.build(source);

        let rebuilt = DagBuilder::from_nodes(dag.nodes.clone());
        let _ = rebuilt.build(dag.root);
    }

    #[test]
    #[should_panic(expected = "DAG root must refer to a node built by this DagBuilder")]
    fn from_snapshot_leaf_only_rejects_foreign_root() {
        let mut source_builder = DagBuilder::<QuadFelt>::new();
        let source = source_builder.constant(felt(1));
        let snapshot = source_builder.build(source).into_snapshot();

        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(3));
        let foreign_dag = foreign_builder.build(foreign);

        let rebuilt = DagBuilder::from_snapshot(snapshot);
        let _ = rebuilt.build(foreign_dag.root);
    }
}