Skip to main content

miden_air/
ace.rs

//! ACE circuit integration for ProcessorAir.
//!
//! This module contains:
//! - Batching types and functions (`MessageElement`, `ProductFactor`, `ReducedAuxBatchConfig`,
//!   `batch_reduced_aux_values`) that extend a constraint check DAG with the auxiliary trace
//!   boundary checks.
//! - The AIR-specific `reduced_aux_batch_config()` that describes the Miden VM's auxiliary trace
//!   boundary checks.
//! - The convenience function `build_batched_ace_circuit()` that builds the full batched circuit in
//!   one call.
//!
//! The formula checked by the batched circuit is:
//!   `constraint_check + gamma * product_check + gamma^2 * sum_check = 0`

15use alloc::vec::Vec;
16
17use miden_ace_codegen::{
18    AceCircuit, AceConfig, AceDag, AceError, DagBuilder, InputKey, NodeId, build_ace_dag_for_air,
19};
20use miden_core::{Felt, field::ExtensionField};
21use miden_crypto::{
22    field::Algebra,
23    stark::air::{LiftedAir, symbolic::SymbolicExpressionExt},
24};
25
26use crate::trace;
27
// BATCHING TYPES
// ================================================================================================

31/// An element in a bus message encoding.
32///
33/// Bus messages are encoded as `alpha + sum(beta^i * elements[i])` using the
34/// aux randomness challenges. Each element is either a constant base-field value
35/// or a reference to a fixed-length public input.
36#[derive(Debug, Clone)]
37pub enum MessageElement {
38    /// A constant base-field value.
39    Constant(Felt),
40    /// A fixed-length public input value, indexed into the public values array.
41    PublicInput(usize),
42}
43
44/// A multiplicative factor in the product check.
45///
46/// The product check verifies:
47///   `product(numerator) - product(denominator) = 0`
48#[derive(Debug, Clone)]
49pub enum ProductFactor {
50    /// Claimed final value of an auxiliary trace column, by column index.
51    BusBoundary(usize),
52    /// A bus message computed from its elements as `alpha + sum(beta^i * elements[i])`.
53    Message(Vec<MessageElement>),
54    /// Multiset product reduced from variable-length public inputs, by group index.
55    Vlpi(usize),
56}
57
58/// Configuration for building the reduced_aux_values batching in the ACE DAG.
59///
60/// Describes the auxiliary trace boundary checks (product_check and sum_check).
61/// Constructed by AIR-specific code (see [`reduced_aux_batch_config`]) and
62/// consumed by [`batch_reduced_aux_values`].
63///
64/// The product check verifies:
65///   `product(numerator) - product(denominator) = 0`
66#[derive(Debug, Clone)]
67pub struct ReducedAuxBatchConfig {
68    /// Factors multiplied into the numerator of the product check.
69    pub numerator: Vec<ProductFactor>,
70    /// Factors multiplied into the denominator of the product check.
71    pub denominator: Vec<ProductFactor>,
72    /// Auxiliary trace column indices whose claimed final values are summed in the sum check.
73    pub sum_columns: Vec<usize>,
74}
75
// BATCHING FUNCTIONS
// ================================================================================================

79/// Extend an existing constraint DAG with auxiliary trace boundary checks.
80///
81/// Takes the constraint DAG and appends the running-product identity check
82/// (product_check) and the LogUp sum check (sum_check), combining all three
83/// checks into a single root with gamma:
84///
85///   `root = constraint_check + gamma * product_check + gamma^2 * sum_check`
86///
87/// Returns the new DAG with the batched root.
88pub fn batch_reduced_aux_values<EF>(
89    constraint_dag: AceDag<EF>,
90    config: &ReducedAuxBatchConfig,
91) -> AceDag<EF>
92where
93    EF: ExtensionField<Felt>,
94{
95    let constraint_root = constraint_dag.root;
96    let mut builder = DagBuilder::from_dag(constraint_dag);
97
98    // Build product_check.
99    let product_check = build_product_check(&mut builder, config);
100
101    // Build sum_check.
102    let sum_check = build_sum_check(&mut builder, config);
103
104    // Batch: root = constraint_check + gamma * product_check + gamma^2 * sum_check
105    let gamma = builder.input(InputKey::Gamma);
106    let gamma2 = builder.mul(gamma, gamma);
107    let term2 = builder.mul(gamma, product_check);
108    let term3 = builder.mul(gamma2, sum_check);
109    let partial = builder.add(constraint_root, term2);
110    let root = builder.add(partial, term3);
111
112    builder.build(root)
113}
114
115/// Build the running-product identity check.
116fn build_product_check<EF>(builder: &mut DagBuilder<EF>, config: &ReducedAuxBatchConfig) -> NodeId
117where
118    EF: ExtensionField<Felt>,
119{
120    let numerator = build_product(builder, &config.numerator);
121    let denominator = build_product(builder, &config.denominator);
122    builder.sub(numerator, denominator)
123}
124
125/// Build a product of factors as a single DAG node.
126fn build_product<EF>(builder: &mut DagBuilder<EF>, factors: &[ProductFactor]) -> NodeId
127where
128    EF: ExtensionField<Felt>,
129{
130    let mut acc = builder.constant(EF::ONE);
131    for factor in factors {
132        let node = match factor {
133            ProductFactor::BusBoundary(idx) => builder.input(InputKey::AuxBusBoundary(*idx)),
134            ProductFactor::Message(elements) => encode_bus_message(builder, elements),
135            ProductFactor::Vlpi(idx) => builder.input(InputKey::VlpiReduction(*idx)),
136        };
137        acc = builder.mul(acc, node);
138    }
139    acc
140}
141
142/// Build the LogUp sum check (sum_check).
143///
144/// Verifies that the LogUp auxiliary columns sum to zero at the boundary.
145fn build_sum_check<EF>(builder: &mut DagBuilder<EF>, config: &ReducedAuxBatchConfig) -> NodeId
146where
147    EF: ExtensionField<Felt>,
148{
149    let mut sum = builder.constant(EF::ZERO);
150    for &col_idx in &config.sum_columns {
151        let col = builder.input(InputKey::AuxBusBoundary(col_idx));
152        sum = builder.add(sum, col);
153    }
154    sum
155}
156
157/// Encode a bus message as `alpha + sum(beta^i * elements[i])`.
158fn encode_bus_message<EF>(builder: &mut DagBuilder<EF>, elements: &[MessageElement]) -> NodeId
159where
160    EF: ExtensionField<Felt>,
161{
162    let alpha = builder.input(InputKey::AuxRandAlpha);
163    let beta = builder.input(InputKey::AuxRandBeta);
164
165    // acc = alpha + sum(beta^i * elem_i)
166    //
167    // Beta powers are built incrementally. The DagBuilder is hash-consed, so
168    // identical beta^i nodes across multiple message encodings are shared
169    // automatically.
170    let mut acc = alpha;
171    let mut beta_power = builder.constant(EF::ONE);
172    for elem in elements {
173        let node = match elem {
174            MessageElement::Constant(f) => builder.constant(EF::from(*f)),
175            MessageElement::PublicInput(idx) => builder.input(InputKey::Public(*idx)),
176        };
177        let term = builder.mul(beta_power, node);
178        acc = builder.add(acc, term);
179        beta_power = builder.mul(beta_power, beta);
180    }
181    acc
182}
183
// AIR-SPECIFIC CONFIG
// ================================================================================================

187/// Build the [`ReducedAuxBatchConfig`] for the Miden VM ProcessorAir.
188///
189/// This encodes the `reduced_aux_values` formula in the Miden VM AIR.
190pub fn reduced_aux_batch_config() -> ReducedAuxBatchConfig {
191    use MessageElement::{Constant, PublicInput};
192    use ProductFactor::{BusBoundary, Message, Vlpi};
193
194    // Aux boundary column indices.
195    let p1 = trace::DECODER_AUX_TRACE_OFFSET;
196    let p2 = trace::DECODER_AUX_TRACE_OFFSET + 1;
197    let p3 = trace::DECODER_AUX_TRACE_OFFSET + 2;
198    let s_aux = trace::STACK_AUX_TRACE_OFFSET;
199    let b_range = trace::RANGE_CHECK_AUX_TRACE_OFFSET;
200    let b_hash_kernel = trace::HASH_KERNEL_VTABLE_AUX_TRACE_OFFSET;
201    let b_chiplets = trace::CHIPLETS_BUS_AUX_TRACE_OFFSET;
202    let v_wiring = trace::ACE_CHIPLET_WIRING_BUS_OFFSET;
203
204    // Public input layout offsets.
205    // [0..4] program hash, [4..20] stack inputs, [20..36] stack outputs, [36..40] transcript state
206    let pv_program_hash = super::PV_PROGRAM_HASH;
207    let pv_transcript_state = super::PV_TRANSCRIPT_STATE;
208
209    // Bus message constants.
210    let log_precompile_label = Felt::from_u8(trace::LOG_PRECOMPILE_LABEL);
211
212    // ph_msg = encode([0, ph[0], ph[1], ph[2], ph[3], 0, 0])
213    // Matches program_hash_message() in lib.rs.
214    let ph_msg = vec![
215        Constant(Felt::ZERO),             // parent_id = 0
216        PublicInput(pv_program_hash),     // hash[0]
217        PublicInput(pv_program_hash + 1), // hash[1]
218        PublicInput(pv_program_hash + 2), // hash[2]
219        PublicInput(pv_program_hash + 3), // hash[3]
220        Constant(Felt::ZERO),             // is_first_child = false
221        Constant(Felt::ZERO),             // is_loop_body = false
222    ];
223
224    // default_msg = encode([LOG_PRECOMPILE_LABEL, 0, 0, 0, 0])
225    // Matches transcript_message(challenges, PrecompileTranscriptState::default()).
226    let default_msg = vec![
227        Constant(log_precompile_label),
228        Constant(Felt::ZERO),
229        Constant(Felt::ZERO),
230        Constant(Felt::ZERO),
231        Constant(Felt::ZERO),
232    ];
233
234    // final_msg = encode([LOG_PRECOMPILE_LABEL, ts[0], ts[1], ts[2], ts[3]])
235    // Matches transcript_message(challenges, pc_transcript_state).
236    let final_msg = vec![
237        Constant(log_precompile_label),
238        PublicInput(pv_transcript_state),
239        PublicInput(pv_transcript_state + 1),
240        PublicInput(pv_transcript_state + 2),
241        PublicInput(pv_transcript_state + 3),
242    ];
243
244    // product_check: product(numerator) - product(denominator) = 0
245    // sum_check:     sum(sum_columns) = 0
246    ReducedAuxBatchConfig {
247        numerator: vec![
248            BusBoundary(p1),
249            BusBoundary(p2),
250            BusBoundary(p3),
251            BusBoundary(s_aux),
252            BusBoundary(b_hash_kernel),
253            BusBoundary(b_chiplets),
254            Message(ph_msg),
255            Message(default_msg),
256        ],
257        denominator: vec![Message(final_msg), Vlpi(0)],
258        sum_columns: vec![b_range, v_wiring],
259    }
260}
261
// CONVENIENCE FUNCTION
// ================================================================================================

265/// Build a batched ACE circuit for the provided AIR.
266///
267/// This is the highest-level entry point for building the ACE circuit for Miden VM AIR.
268/// It builds the constraint-evaluation DAG, extends it with the auxiliary trace
269/// boundary checks and emits the off-VM circuit representation.
270///
271/// The output circuit checks:
272///   `constraint_check + gamma * product_check + gamma^2 * sum_check = 0`
273pub fn build_batched_ace_circuit<A, EF>(
274    air: &A,
275    config: AceConfig,
276    batch_config: &ReducedAuxBatchConfig,
277) -> Result<AceCircuit<EF>, AceError>
278where
279    A: LiftedAir<Felt, EF>,
280    EF: ExtensionField<Felt>,
281    SymbolicExpressionExt<Felt, EF>: Algebra<EF>,
282{
283    let artifacts = build_ace_dag_for_air::<A, Felt, EF>(air, config)?;
284    let batched_dag = batch_reduced_aux_values(artifacts.dag, batch_config);
285    miden_ace_codegen::emit_circuit(&batched_dag, artifacts.layout)
286}