Skip to main content

miden_ace_codegen/
circuit.rs

//! ACE circuit emission for the DAG IR.
//!
//! The emitted circuit is a flat list of inputs, constants, and arithmetic
//! operations that matches the execution model of the ACE chiplet.
5
6use std::collections::HashMap;
7
8use miden_crypto::field::Field;
9
10use crate::{
11    AceError, InputLayout,
12    dag::{AceDag, NodeId, NodeKind},
13};
14
/// Binary arithmetic operation performed by a single ACE circuit gate.
///
/// These are the only operations the circuit supports; negation is lowered to
/// a subtraction from zero by [`emit_circuit`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum AceOp {
    /// `lhs + rhs`
    Add,
    /// `lhs - rhs`
    Sub,
    /// `lhs * rhs`
    Mul,
}
22
/// Reference to a value in the emitted ACE circuit.
///
/// Each variant carries an index into the table it names: the input vector,
/// the circuit's constant list, or its operation list. The derived ordering
/// compares the variant first (`Input < Constant < Operation`) and then the
/// index.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub(crate) enum AceNode {
    /// Index into the input vector supplied at evaluation time.
    Input(usize),
    /// Index into the circuit's constant list.
    Constant(usize),
    /// Index into the circuit's operation list (the result of that operation).
    Operation(usize),
}
30
31/// Operation node in the ACE circuit.
32#[derive(Debug, Clone, Copy, PartialEq, Eq)]
33pub(crate) struct AceOpNode {
34    pub op: AceOp,
35    pub lhs: AceNode,
36    pub rhs: AceNode,
37}
38
39/// Emitted ACE circuit with layout and operation list.
40///
41/// This is the off-VM representation used by tests and tools.
42#[derive(Debug, Clone)]
43pub struct AceCircuit<EF> {
44    pub(crate) layout: InputLayout,
45    pub(crate) constants: Vec<EF>,
46    pub(crate) operations: Vec<AceOpNode>,
47    pub(crate) root: AceNode,
48}
49
50impl<EF: Field> AceCircuit<EF> {
51    /// Return the input layout for this circuit.
52    pub fn layout(&self) -> &InputLayout {
53        &self.layout
54    }
55
56    /// Evaluate the circuit against the provided input vector.
57    pub fn eval(&self, inputs: &[EF]) -> Result<EF, AceError> {
58        if inputs.len() != self.layout.total_inputs {
59            return Err(AceError::InvalidInputLength {
60                expected: self.layout.total_inputs,
61                got: inputs.len(),
62            });
63        }
64        let mut op_values = vec![EF::ZERO; self.operations.len()];
65        for (idx, op) in self.operations.iter().enumerate() {
66            let lhs = self.node_value(op.lhs, inputs, &op_values);
67            let rhs = self.node_value(op.rhs, inputs, &op_values);
68            op_values[idx] = match op.op {
69                AceOp::Add => lhs + rhs,
70                AceOp::Sub => lhs - rhs,
71                AceOp::Mul => lhs * rhs,
72            };
73        }
74        Ok(self.node_value(self.root, inputs, &op_values))
75    }
76
77    /// Total number of nodes (inputs + constants + ops).
78    pub fn num_nodes(&self) -> usize {
79        self.layout.total_inputs + self.constants.len() + self.operations.len()
80    }
81
82    fn node_value(&self, node: AceNode, inputs: &[EF], op_values: &[EF]) -> EF {
83        match node {
84            AceNode::Input(index) => inputs[index],
85            AceNode::Constant(index) => self.constants[index],
86            AceNode::Operation(index) => op_values[index],
87        }
88    }
89}
90
91/// Emit an ACE circuit from the DAG and input layout.
92pub fn emit_circuit<EF>(dag: &AceDag<EF>, layout: InputLayout) -> Result<AceCircuit<EF>, AceError>
93where
94    EF: Field,
95{
96    let mut constants = Vec::new();
97    let mut constant_map = HashMap::<EF, usize>::new();
98    let mut operations = Vec::new();
99    let mut node_map: Vec<Option<AceNode>> = vec![None; dag.nodes().len()];
100
101    for (idx, node) in dag.nodes().iter().enumerate() {
102        let ace_node = match node {
103            NodeKind::Input(key) => {
104                let input_idx = layout.index(*key).ok_or_else(|| AceError::InvalidInputLayout {
105                    message: format!("missing input key in layout: {key:?}"),
106                })?;
107                AceNode::Input(input_idx)
108            },
109            NodeKind::Constant(value) => {
110                let const_idx = *constant_map.entry(*value).or_insert_with(|| {
111                    constants.push(*value);
112                    constants.len() - 1
113                });
114                AceNode::Constant(const_idx)
115            },
116            NodeKind::Add(a, b) => {
117                let lhs = lookup_node(&node_map, *a);
118                let rhs = lookup_node(&node_map, *b);
119                let op_idx = operations.len();
120                operations.push(AceOpNode { op: AceOp::Add, lhs, rhs });
121                AceNode::Operation(op_idx)
122            },
123            NodeKind::Sub(a, b) => {
124                let lhs = lookup_node(&node_map, *a);
125                let rhs = lookup_node(&node_map, *b);
126                let op_idx = operations.len();
127                operations.push(AceOpNode { op: AceOp::Sub, lhs, rhs });
128                AceNode::Operation(op_idx)
129            },
130            NodeKind::Mul(a, b) => {
131                let lhs = lookup_node(&node_map, *a);
132                let rhs = lookup_node(&node_map, *b);
133                let op_idx = operations.len();
134                operations.push(AceOpNode { op: AceOp::Mul, lhs, rhs });
135                AceNode::Operation(op_idx)
136            },
137            NodeKind::Neg(a) => {
138                let rhs = lookup_node(&node_map, *a);
139                let zero = *constant_map.entry(EF::ZERO).or_insert_with(|| {
140                    constants.push(EF::ZERO);
141                    constants.len() - 1
142                });
143                let op_idx = operations.len();
144                operations.push(AceOpNode {
145                    op: AceOp::Sub,
146                    lhs: AceNode::Constant(zero),
147                    rhs,
148                });
149                AceNode::Operation(op_idx)
150            },
151        };
152        node_map[idx] = Some(ace_node);
153    }
154
155    let root = lookup_node(&node_map, dag.root());
156    Ok(AceCircuit { layout, constants, operations, root })
157}
158
159fn lookup_node(map: &[Option<AceNode>], id: NodeId) -> AceNode {
160    map[id.index()].expect("ACE DAG nodes must be topologically ordered")
161}