1use std::collections::HashMap;
2
3use miden_crypto::field::Field;
4
5use super::ir::{DagId, DagSnapshot, NodeId, NodeKind};
6use crate::layout::InputKey;
7
#[derive(Debug)]
/// Incrementally builds a DAG of hash-consed arithmetic nodes over the field `EF`.
///
/// Every [`NodeId`] handed out by a builder is tagged with the builder's [`DagId`]. Operations
/// reject ids tagged with any other DAG, except ids from a DAG whose nodes were imported into
/// this builder (see `from_nodes` / `from_snapshot`).
pub struct DagBuilder<EF> {
    /// Identity stamped into every [`NodeId`] this builder creates.
    dag_id: DagId,
    /// Nodes in insertion order; a node's position in this vector is its index.
    nodes: Vec<NodeKind<EF>>,
    /// Hash-consing table mapping each stored node to the id it was stored under.
    cache: HashMap<NodeKind<EF>, NodeId>,
    /// DAG the initial nodes were imported from, if any; its ids are still accepted.
    imported_dag: Option<ImportedDag>,
}
19
20impl<EF> DagBuilder<EF>
21where
22 EF: Field,
23{
24 pub fn new() -> Self {
26 Self {
27 dag_id: DagId::fresh(),
28 nodes: Vec::new(),
29 cache: HashMap::new(),
30 imported_dag: None,
31 }
32 }
33
34 pub fn from_nodes(nodes: Vec<NodeKind<EF>>) -> Self {
39 let imported_dag = infer_dag_id(&nodes)
40 .map(|source_dag_id| ImportedDag { source_dag_id, imported_len: nodes.len() });
41 let dag_id = DagId::fresh();
42 let nodes = rebase_nodes(nodes, dag_id);
43
44 Self::from_existing_nodes(dag_id, nodes, imported_dag)
45 }
46
47 pub fn from_snapshot(snapshot: DagSnapshot<EF>) -> Self {
51 let (source_dag_id, nodes, _) = snapshot.into_parts();
52 let dag_id = DagId::fresh();
53 let imported_dag = Some(ImportedDag { source_dag_id, imported_len: nodes.len() });
54 let nodes = rebase_nodes(nodes, dag_id);
55
56 Self::from_existing_nodes(dag_id, nodes, imported_dag)
57 }
58
59 pub fn from_dag(dag: super::AceDag<EF>) -> Self {
64 let dag_id = dag.dag_id();
65 Self::from_existing_nodes(dag_id, dag.into_nodes(), None)
66 }
67
68 fn from_existing_nodes(
69 dag_id: DagId,
70 nodes: Vec<NodeKind<EF>>,
71 imported_dag: Option<ImportedDag>,
72 ) -> Self {
73 let cache = nodes
74 .iter()
75 .enumerate()
76 .map(|(i, n)| (n.clone(), NodeId::in_dag(i, dag_id)))
77 .collect();
78 Self { dag_id, nodes, cache, imported_dag }
79 }
80
81 pub fn into_nodes(self) -> Vec<NodeKind<EF>> {
83 self.nodes
84 }
85
86 pub fn build(self, root: NodeId) -> super::AceDag<EF> {
88 let root = self.resolve_id(root, "DAG root must refer to a node built by this DagBuilder");
89
90 super::AceDag::from_parts(self.dag_id, self.nodes, root)
91 }
92
93 pub fn input(&mut self, key: InputKey) -> NodeId {
95 self.intern(NodeKind::Input(key))
96 }
97
98 pub fn constant(&mut self, value: EF) -> NodeId {
100 self.intern(NodeKind::Constant(value))
101 }
102
103 pub fn add(&mut self, a: NodeId, b: NodeId) -> NodeId {
105 let a = self.resolve_node(a);
106 let b = self.resolve_node(b);
107 if let (Some(x), Some(y)) = (self.const_value(a), self.const_value(b)) {
108 return self.constant(x + y);
109 }
110 if self.is_zero(a) {
111 return b;
112 }
113 if self.is_zero(b) {
114 return a;
115 }
116 let (l, r) = if a <= b { (a, b) } else { (b, a) };
117 self.intern(NodeKind::Add(l, r))
118 }
119
120 pub fn sub(&mut self, a: NodeId, b: NodeId) -> NodeId {
122 let a = self.resolve_node(a);
123 let b = self.resolve_node(b);
124 if let (Some(x), Some(y)) = (self.const_value(a), self.const_value(b)) {
125 return self.constant(x - y);
126 }
127 if self.is_zero(b) {
128 return a;
129 }
130 self.intern(NodeKind::Sub(a, b))
131 }
132
133 pub fn mul(&mut self, a: NodeId, b: NodeId) -> NodeId {
135 let a = self.resolve_node(a);
136 let b = self.resolve_node(b);
137 if let (Some(x), Some(y)) = (self.const_value(a), self.const_value(b)) {
138 return self.constant(x * y);
139 }
140 if self.is_zero(a) || self.is_zero(b) {
141 return self.constant(EF::ZERO);
142 }
143 if self.is_one(a) {
144 return b;
145 }
146 if self.is_one(b) {
147 return a;
148 }
149 let (l, r) = if a <= b { (a, b) } else { (b, a) };
150 self.intern(NodeKind::Mul(l, r))
151 }
152
153 pub fn neg(&mut self, a: NodeId) -> NodeId {
155 let a = self.resolve_node(a);
156 if let Some(x) = self.const_value(a) {
157 return self.constant(-x);
158 }
159 self.intern(NodeKind::Neg(a))
160 }
161
162 fn const_value(&self, id: NodeId) -> Option<EF> {
163 match self.nodes.get(id.index())? {
164 NodeKind::Constant(v) => Some(*v),
165 _ => None,
166 }
167 }
168
169 fn is_zero(&self, id: NodeId) -> bool {
170 self.const_value(id).is_some_and(|v| v == EF::ZERO)
171 }
172
173 fn is_one(&self, id: NodeId) -> bool {
174 self.const_value(id).is_some_and(|v| v == EF::ONE)
175 }
176
177 fn resolve_node(&self, id: NodeId) -> NodeId {
178 self.resolve_id(id, "DAG node must come from this DagBuilder")
179 }
180
181 fn intern(&mut self, node: NodeKind<EF>) -> NodeId {
182 if let Some(id) = self.cache.get(&node) {
183 return *id;
184 }
185 let id = NodeId::in_dag(self.nodes.len(), self.dag_id);
186 self.nodes.push(node.clone());
187 self.cache.insert(node, id);
188 id
189 }
190
191 fn resolve_id(&self, id: NodeId, message: &str) -> NodeId {
192 assert!(id.index() < self.nodes.len(), "{message}");
193
194 if id.dag_id == self.dag_id {
195 return id;
196 }
197
198 if let Some(imported) = &self.imported_dag
199 && imported.source_dag_id == id.dag_id
200 && id.index() < imported.imported_len
201 {
202 return NodeId::in_dag(id.index(), self.dag_id);
203 }
204
205 panic!("{message}");
206 }
207}
208
209fn infer_dag_id<EF>(nodes: &[NodeKind<EF>]) -> Option<DagId> {
210 nodes.iter().find_map(|node| match node {
211 NodeKind::Add(a, _) | NodeKind::Sub(a, _) | NodeKind::Mul(a, _) | NodeKind::Neg(a) => {
212 Some(a.dag_id)
213 },
214 NodeKind::Input(_) | NodeKind::Constant(_) => None,
215 })
216}
217
218fn rebase_nodes<EF>(nodes: Vec<NodeKind<EF>>, dag_id: DagId) -> Vec<NodeKind<EF>> {
219 nodes
220 .into_iter()
221 .map(|node| match node {
222 NodeKind::Input(key) => NodeKind::Input(key),
223 NodeKind::Constant(value) => NodeKind::Constant(value),
224 NodeKind::Add(a, b) => NodeKind::Add(rebase_node(a, dag_id), rebase_node(b, dag_id)),
225 NodeKind::Sub(a, b) => NodeKind::Sub(rebase_node(a, dag_id), rebase_node(b, dag_id)),
226 NodeKind::Mul(a, b) => NodeKind::Mul(rebase_node(a, dag_id), rebase_node(b, dag_id)),
227 NodeKind::Neg(a) => NodeKind::Neg(rebase_node(a, dag_id)),
228 })
229 .collect()
230}
231
232fn rebase_node(id: NodeId, dag_id: DagId) -> NodeId {
233 NodeId::in_dag(id.index(), dag_id)
234}
235
/// Records the DAG that a builder's initial nodes were imported from.
///
/// Ids tagged with `source_dag_id` whose index is below `imported_len` are translated to the node
/// at the same index in the importing builder.
#[derive(Debug, Clone)]
struct ImportedDag {
    /// Identity of the DAG the nodes were copied from.
    source_dag_id: DagId,
    /// Number of nodes copied; only source ids with a smaller index are accepted.
    imported_len: usize,
}
241
242impl<EF> Default for DagBuilder<EF>
243where
244 EF: Field,
245{
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
#[cfg(test)]
mod tests {
    use miden_core::{Felt, field::QuadFelt};

    use super::DagBuilder;
    use crate::layout::InputKey;

    /// Lifts a small integer into the extension field used by these tests.
    fn felt(value: u64) -> QuadFelt {
        QuadFelt::from(Felt::new(value))
    }

    #[test]
    #[should_panic(expected = "DAG root must refer to a node built by this DagBuilder")]
    fn build_rejects_same_index_root_from_another_builder() {
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign_root = foreign_builder.constant(felt(1));

        let mut builder = DagBuilder::<QuadFelt>::new();
        builder.constant(felt(1));

        let _ = builder.build(foreign_root);
    }

    #[test]
    #[should_panic(expected = "DAG node must come from this DagBuilder")]
    fn add_rejects_foreign_node() {
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(2));

        let mut builder = DagBuilder::<QuadFelt>::new();
        let local = builder.constant(felt(1));

        let _ = builder.add(local, foreign);
    }

    #[test]
    #[should_panic(expected = "DAG node must come from this DagBuilder")]
    fn sub_rejects_foreign_node() {
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(2));

        let mut builder = DagBuilder::<QuadFelt>::new();
        let local = builder.constant(felt(1));

        let _ = builder.sub(local, foreign);
    }

    #[test]
    #[should_panic(expected = "DAG node must come from this DagBuilder")]
    fn mul_rejects_foreign_node() {
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(2));

        let mut builder = DagBuilder::<QuadFelt>::new();
        let local = builder.constant(felt(1));

        let _ = builder.mul(local, foreign);
    }

    #[test]
    #[should_panic(expected = "DAG node must come from this DagBuilder")]
    fn neg_rejects_foreign_node() {
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(2));

        let mut builder = DagBuilder::<QuadFelt>::new();
        let _ = builder.constant(felt(1));

        let _ = builder.neg(foreign);
    }

    #[test]
    fn commutative_ops_are_deduplicated() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let gamma = builder.input(InputKey::Gamma);
        let two = builder.constant(felt(2));

        let sum_ab = builder.add(gamma, two);
        let sum_ba = builder.add(two, gamma);
        assert_eq!(sum_ab, sum_ba);

        let prod_ab = builder.mul(gamma, two);
        let prod_ba = builder.mul(two, gamma);
        assert_eq!(prod_ab, prod_ba);

        // Subtraction is not commutative, so the operand order must be kept.
        let diff_ab = builder.sub(gamma, two);
        let diff_ba = builder.sub(two, gamma);
        assert_ne!(diff_ab, diff_ba);

        // Re-requesting an existing leaf returns the same node.
        let gamma_again = builder.input(InputKey::Gamma);
        assert_eq!(gamma_again, gamma);
    }

    #[test]
    fn constants_are_folded_and_identities_applied() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let gamma = builder.input(InputKey::Gamma);
        let zero = builder.constant(felt(0));
        let one = builder.constant(felt(1));
        let two = builder.constant(felt(2));
        let three = builder.constant(felt(3));
        let five = builder.constant(felt(5));
        let minus_two = builder.constant(-felt(2));

        let folded_sum = builder.add(two, three);
        assert_eq!(folded_sum, five);
        let folded_neg = builder.neg(two);
        assert_eq!(folded_neg, minus_two);

        let plus_zero = builder.add(gamma, zero);
        assert_eq!(plus_zero, gamma);
        let minus_zero = builder.sub(gamma, zero);
        assert_eq!(minus_zero, gamma);
        let times_zero = builder.mul(gamma, zero);
        assert_eq!(times_zero, zero);
        let times_one = builder.mul(one, gamma);
        assert_eq!(times_one, gamma);
    }

    #[test]
    fn from_dag_preserves_node_ownership() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let a = builder.constant(felt(1));
        let dag = builder.build(a);
        let root = dag.root();

        let mut rebuilt = DagBuilder::from_dag(dag);
        let b = rebuilt.constant(felt(2));
        let sum = rebuilt.add(root, b);

        let rebuilt_dag = rebuilt.build(sum);
        assert_eq!(rebuilt_dag.root().index(), sum.index());
    }

    #[test]
    fn from_nodes_accepts_published_root_shape() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let a = builder.input(InputKey::Gamma);
        let b = builder.constant(felt(2));
        let root = builder.add(a, b);
        let dag = builder.build(root);

        let mut rebuilt = DagBuilder::from_nodes(dag.nodes.clone());
        let c = rebuilt.constant(felt(3));
        let sum = rebuilt.add(dag.root, c);

        let rebuilt_dag = rebuilt.build(sum);
        assert_eq!(rebuilt_dag.root().index(), sum.index());
    }

    #[test]
    fn from_snapshot_accepts_leaf_only_root_shape() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let a = builder.constant(felt(1));
        let dag = builder.build(a);

        let root = dag.root();
        let mut rebuilt = DagBuilder::from_snapshot(dag.into_snapshot());
        let b = rebuilt.constant(felt(2));
        let sum = rebuilt.add(root, b);

        let rebuilt_dag = rebuilt.build(sum);
        assert_eq!(rebuilt_dag.root().index(), sum.index());
    }

    #[test]
    fn from_snapshot_accepts_leaf_only_root_after_source_dag_is_dropped() {
        let mut builder = DagBuilder::<QuadFelt>::new();
        let a = builder.constant(felt(1));
        let snapshot = builder.build(a).into_snapshot();
        let root = snapshot.root();

        let mut rebuilt = DagBuilder::from_snapshot(snapshot);
        let b = rebuilt.constant(felt(2));
        let sum = rebuilt.add(root, b);

        let rebuilt_dag = rebuilt.build(sum);
        assert_eq!(rebuilt_dag.root().index(), sum.index());
    }

    #[test]
    #[should_panic(expected = "DAG node must come from this DagBuilder")]
    fn from_nodes_rejects_foreign_node_from_another_builder() {
        let mut source_builder = DagBuilder::<QuadFelt>::new();
        let a = source_builder.input(InputKey::Gamma);
        let b = source_builder.constant(felt(2));
        let root = source_builder.add(a, b);
        let dag = source_builder.build(root);

        let mut rebuilt = DagBuilder::from_nodes(dag.nodes.clone());
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(3));

        let _ = rebuilt.add(dag.root, foreign);
    }

    #[test]
    #[should_panic(expected = "DAG root must refer to a node built by this DagBuilder")]
    fn from_nodes_rejects_foreign_root_from_another_builder() {
        let mut source_builder = DagBuilder::<QuadFelt>::new();
        let a = source_builder.input(InputKey::Gamma);
        let b = source_builder.constant(felt(2));
        let root = source_builder.add(a, b);
        let dag = source_builder.build(root);

        let rebuilt = DagBuilder::from_nodes(dag.nodes.clone());
        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(3));

        let _ = rebuilt.build(foreign);
    }

    #[test]
    #[should_panic(expected = "DAG root must refer to a node built by this DagBuilder")]
    fn from_nodes_leaf_only_rejects_foreign_root_before_any_imported_id() {
        // With only leaves there is no operand to infer the source DAG id from, so the source
        // root is treated as foreign.
        let mut source_builder = DagBuilder::<QuadFelt>::new();
        let source = source_builder.constant(felt(1));
        let dag = source_builder.build(source);

        let rebuilt = DagBuilder::from_nodes(dag.nodes.clone());
        let _ = rebuilt.build(dag.root);
    }

    #[test]
    #[should_panic(expected = "DAG root must refer to a node built by this DagBuilder")]
    fn from_snapshot_leaf_only_rejects_foreign_root() {
        let mut source_builder = DagBuilder::<QuadFelt>::new();
        let source = source_builder.constant(felt(1));
        let snapshot = source_builder.build(source).into_snapshot();

        let mut foreign_builder = DagBuilder::<QuadFelt>::new();
        let foreign = foreign_builder.constant(felt(3));
        let foreign_dag = foreign_builder.build(foreign);

        let rebuilt = DagBuilder::from_snapshot(snapshot);
        let _ = rebuilt.build(foreign_dag.root);
    }
}