1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2
3use itertools::Itertools;
4use miden_air::{
5 Felt,
6 trace::{
7 CLK_COL_IDX, DECODER_TRACE_OFFSET, DECODER_TRACE_WIDTH, MIN_TRACE_LEN, MainTrace, RowIndex,
8 STACK_TRACE_OFFSET, STACK_TRACE_WIDTH, SYS_TRACE_WIDTH,
9 decoder::{
10 HASHER_STATE_OFFSET, NUM_HASHER_COLUMNS, NUM_OP_BITS, OP_BITS_EXTRA_COLS_OFFSET,
11 OP_BITS_OFFSET,
12 },
13 stack::{B0_COL_IDX, B1_COL_IDX, H0_COL_IDX, STACK_TOP_OFFSET},
14 },
15};
16use miden_core::{
17 ONE, Word, ZERO,
18 field::batch_inversion_allow_zeros,
19 mast::{MastForest, MastNode},
20 operations::opcodes,
21 program::{Kernel, MIN_STACK_DEPTH},
22};
23use rayon::prelude::*;
24use tracing::instrument;
25
26use crate::{
27 ContextId, ExecutionError,
28 continuation_stack::ContinuationStack,
29 errors::MapExecErrNoCtx,
30 trace::{
31 AuxTraceBuilders, ChipletsLengths, ExecutionTrace, TraceBuildInputs, TraceLenSummary,
32 parallel::{processor::ReplayProcessor, tracer::CoreTraceGenerationTracer},
33 range::RangeChecker,
34 utils::RowMajorTraceWriter,
35 },
36};
37
/// Total number of columns in the row-major core trace: system, decoder, and stack
/// column sections laid out contiguously in that order.
pub const CORE_TRACE_WIDTH: usize = SYS_TRACE_WIDTH + DECODER_TRACE_WIDTH + STACK_TRACE_WIDTH;

/// Hard upper bound (2^29 rows) on the trace length used by [`build_trace`].
const MAX_TRACE_LEN: usize = 1 << 29;

pub(crate) mod core_trace_fragment;

mod processor;
mod tracer;
50
51use super::{
52 chiplets::Chiplets,
53 decoder::AuxTraceBuilder as DecoderAuxTraceBuilder,
54 execution_tracer::TraceGenerationContext,
55 stack::AuxTraceBuilder as StackAuxTraceBuilder,
56 trace_state::{
57 AceReplay, BitwiseOp, BitwiseReplay, CoreTraceFragmentContext, CoreTraceState,
58 ExecutionReplay, HasherOp, HasherRequestReplay, KernelReplay, MemoryWritesReplay,
59 RangeCheckerReplay,
60 },
61};
62
63#[cfg(test)]
64mod tests;
65
/// Builds the final [`ExecutionTrace`] from the provided trace build inputs.
///
/// Delegates to [`build_trace_with_max_len`] with the default [`MAX_TRACE_LEN`] bound.
#[instrument(name = "build_trace", skip_all)]
pub fn build_trace(inputs: TraceBuildInputs) -> Result<ExecutionTrace, ExecutionError> {
    build_trace_with_max_len(inputs, MAX_TRACE_LEN)
}
90
/// Builds the final [`ExecutionTrace`], erroring out if any trace segment would exceed
/// `max_trace_len` rows.
///
/// # Errors
/// - [`ExecutionError::Internal`] if the trace output does not match the deferred
///   precompile requests, or if no core-trace fragments were provided.
/// - [`ExecutionError::TraceLenExceeded`] if the core trace or the chiplets trace would be
///   longer than `max_trace_len`.
pub fn build_trace_with_max_len(
    inputs: TraceBuildInputs,
    max_trace_len: usize,
) -> Result<ExecutionTrace, ExecutionError> {
    let TraceBuildInputs {
        trace_output,
        trace_generation_context,
        program_info,
    } = inputs;

    // Sanity check: the recorded outputs must correspond to the deferred precompile
    // requests captured during execution.
    if !trace_output.has_matching_precompile_requests_digest() {
        return Err(ExecutionError::Internal(
            "trace inputs do not match deferred precompile requests",
        ));
    }

    let TraceGenerationContext {
        core_trace_contexts,
        range_checker_replay,
        memory_writes,
        bitwise_replay: bitwise,
        kernel_replay,
        hasher_for_chiplet,
        ace_replay,
        fragment_size,
    } = trace_generation_context;

    // Upper-bound check on the core trace length: `num_fragments * fragment_size` rows
    // plus one extra row (the trailing HALT row appended by `push_halt_opcode_row`).
    // `checked_*` guards against overflow on pathological inputs.
    let total_core_trace_rows = core_trace_contexts
        .len()
        .checked_mul(fragment_size)
        .and_then(|n| n.checked_add(1))
        .ok_or(ExecutionError::TraceLenExceeded(max_trace_len))?;
    if total_core_trace_rows > max_trace_len {
        return Err(ExecutionError::TraceLenExceeded(max_trace_len));
    }

    if core_trace_contexts.is_empty() {
        return Err(ExecutionError::Internal(
            "no trace fragments provided in the trace generation context",
        ));
    }

    // Chiplets must be built first: the range checker below consumes the range checks
    // that the chiplets (e.g. memory) contribute.
    let chiplets = initialize_chiplets(
        program_info.kernel().clone(),
        &core_trace_contexts,
        memory_writes,
        bitwise,
        kernel_replay,
        hasher_for_chiplet,
        ace_replay,
        max_trace_len,
    )?;

    let range_checker = initialize_range_checker(range_checker_replay, &chiplets);

    // Re-executes each fragment in parallel and produces the row-major core trace
    // (system + decoder + stack columns), consuming the fragment contexts.
    let mut core_trace_data = generate_core_trace_row_major(
        core_trace_contexts,
        program_info.kernel().clone(),
        fragment_size,
    )?;

    let core_trace_len = core_trace_data.len() / CORE_TRACE_WIDTH;

    let range_table_len = range_checker.get_number_range_checker_rows();

    let trace_len_summary =
        TraceLenSummary::new(core_trace_len, range_table_len, ChipletsLengths::new(&chiplets));

    // Final (power-of-two) trace length that all segments get padded to.
    let main_trace_len =
        compute_main_trace_length(core_trace_len, range_table_len, chiplets.trace_len());

    // Build the range-checker and chiplets traces while, in parallel, padding the core
    // trace in place up to `main_trace_len` rows.
    let ((range_checker_trace, chiplets_trace), ()) = rayon::join(
        || {
            rayon::join(
                || range_checker.into_trace_with_table(range_table_len, main_trace_len),
                || chiplets.into_trace(main_trace_len),
            )
        },
        || pad_core_row_major(&mut core_trace_data, main_trace_len),
    );

    let main_trace = {
        // Index of the last non-padding row; saturating for the (guarded-above)
        // degenerate zero-length case.
        let last_program_row = RowIndex::from((core_trace_len as u32).saturating_sub(1));
        MainTrace::from_parts(
            core_trace_data,
            chiplets_trace.trace,
            range_checker_trace.trace,
            main_trace_len,
            last_program_row,
        )
    };

    // Builders needed later for auxiliary (extension-field) trace columns.
    let aux_trace_builders = AuxTraceBuilders {
        decoder: DecoderAuxTraceBuilder::default(),
        range: range_checker_trace.aux_builder,
        chiplets: chiplets_trace.aux_builder,
        stack: StackAuxTraceBuilder,
    };

    Ok(ExecutionTrace::new_from_parts(
        program_info,
        trace_output,
        main_trace,
        aux_trace_builders,
        trace_len_summary,
    ))
}
212
213fn compute_main_trace_length(
217 core_trace_len: usize,
218 range_table_len: usize,
219 chiplets_trace_len: usize,
220) -> usize {
221 let max_len = range_table_len.max(core_trace_len).max(chiplets_trace_len);
223
224 let trace_len = max_len.next_power_of_two();
226 core::cmp::max(trace_len, MIN_TRACE_LEN)
227}
228
/// Generates the row-major core trace (system, decoder, and stack columns) by replaying
/// all trace fragments in parallel.
///
/// Each fragment writes into its own disjoint chunk of a single pre-allocated buffer;
/// fragment-boundary rows are then fixed up, the stack H0 helper column is batch-inverted
/// per fragment, unused rows are truncated away, and a final HALT row is appended.
///
/// # Errors
/// Returns an error if any fragment replay fails, or ([`ExecutionError::Internal`]) if
/// there are no fragments.
fn generate_core_trace_row_major(
    core_trace_contexts: Vec<CoreTraceFragmentContext>,
    kernel: Kernel,
    fragment_size: usize,
) -> Result<Vec<Felt>, ExecutionError> {
    let num_fragments = core_trace_contexts.len();
    let total_allocated_rows = num_fragments * fragment_size;

    // Single row-major buffer; every fragment gets an equally-sized chunk of it.
    let mut core_trace_data: Vec<Felt> = vec![ZERO; total_allocated_rows * CORE_TRACE_WIDTH];

    // Stack-top of the very first fragment's initial state; used later to initialize the
    // first row of the trace. Falls back to all-zeros when there are no fragments.
    let first_stack_top = if let Some(first_context) = core_trace_contexts.first() {
        first_context.state.stack.stack_top.to_vec()
    } else {
        vec![ZERO; MIN_STACK_DEPTH]
    };

    // One writer per fragment, each owning a disjoint mutable chunk of the buffer, which
    // makes the parallel fill below safe without locking.
    let writers: Vec<RowMajorTraceWriter<'_, Felt>> = core_trace_data
        .chunks_exact_mut(fragment_size * CORE_TRACE_WIDTH)
        .map(|chunk| RowMajorTraceWriter::new(chunk, CORE_TRACE_WIDTH))
        .collect();

    // Replay every fragment in parallel; each replay consumes its context and writer and
    // yields the fragment's final state (last rows written + row count).
    let fragment_results: Result<Vec<_>, ExecutionError> = core_trace_contexts
        .into_par_iter()
        .zip(writers.into_par_iter())
        .map(|(trace_state, writer)| {
            let (mut processor, mut tracer, mut continuation_stack, mut current_forest) =
                split_trace_fragment_context(trace_state, writer, fragment_size);

            processor.execute(
                &mut continuation_stack,
                &mut current_forest,
                &kernel,
                &mut tracer,
            )?;

            tracer.into_final_state()
        })
        .collect();
    let fragment_results = fragment_results?;

    // Collect per-fragment boundary state: the stack/system columns of each fragment's
    // last row seed the first row of the next fragment.
    let mut stack_rows = Vec::new();
    let mut system_rows = Vec::new();
    let mut total_core_trace_rows = 0;

    for final_state in fragment_results {
        stack_rows.push(final_state.last_stack_cols);
        system_rows.push(final_state.last_system_cols);
        total_core_trace_rows += final_state.num_rows_written;
    }

    fixup_stack_and_system_rows(
        &mut core_trace_data,
        fragment_size,
        &stack_rows,
        &system_rows,
        &first_stack_top,
    );

    // Invert the stack H0 helper column in place, one batch inversion per fragment chunk
    // (only over the rows that were actually written). Batch inversion amortizes the cost
    // of field inversions across each chunk; zeros are passed through unchanged.
    {
        let h0_col_offset = STACK_TRACE_OFFSET + H0_COL_IDX;
        let w = CORE_TRACE_WIDTH;
        core_trace_data[..total_core_trace_rows * w]
            .par_chunks_mut(fragment_size * w)
            .for_each(|fragment_chunk| {
                let num_rows = fragment_chunk.len() / w;
                let mut h0_vals: Vec<Felt> =
                    (0..num_rows).map(|r| fragment_chunk[r * w + h0_col_offset]).collect();
                batch_inversion_allow_zeros(&mut h0_vals);
                for (r, &val) in h0_vals.iter().enumerate() {
                    fragment_chunk[r * w + h0_col_offset] = val;
                }
            });
    }

    // Drop the unused tail of the last fragment's chunk (fragments may finish early).
    core_trace_data.truncate(total_core_trace_rows * CORE_TRACE_WIDTH);

    // Append the final HALT row, carrying forward the last fragment's system/stack state.
    push_halt_opcode_row(
        &mut core_trace_data,
        total_core_trace_rows,
        system_rows.last().ok_or(ExecutionError::Internal(
            "no trace fragments provided in the trace generation context",
        ))?,
        stack_rows.last().ok_or(ExecutionError::Internal(
            "no trace fragments provided in the trace generation context",
        ))?,
    );

    Ok(core_trace_data)
}
326
327fn fixup_stack_and_system_rows(
334 core_trace_data: &mut [Felt],
335 fragment_size: usize,
336 stack_rows: &[[Felt; STACK_TRACE_WIDTH]],
337 system_rows: &[[Felt; SYS_TRACE_WIDTH]],
338 first_stack_top: &[Felt],
339) {
340 const MIN_STACK_DEPTH_FELT: Felt = Felt::new(MIN_STACK_DEPTH as u64);
341 let w = CORE_TRACE_WIDTH;
342
343 {
344 let row = &mut core_trace_data[..w];
345
346 for (stack_col_idx, &value) in first_stack_top.iter().rev().enumerate() {
348 row[STACK_TRACE_OFFSET + STACK_TOP_OFFSET + stack_col_idx] = value;
349 }
350
351 row[STACK_TRACE_OFFSET + B0_COL_IDX] = MIN_STACK_DEPTH_FELT;
352 row[STACK_TRACE_OFFSET + B1_COL_IDX] = ZERO;
353 row[STACK_TRACE_OFFSET + H0_COL_IDX] = ZERO;
354 }
355
356 let total_rows = core_trace_data.len() / w;
357 let num_fragments = total_rows / fragment_size;
358
359 for frag_idx in 1..num_fragments {
360 let row_idx = frag_idx * fragment_size;
361 let row_start = row_idx * w;
362 let system_row = &system_rows[frag_idx - 1];
363 let stack_row = &stack_rows[frag_idx - 1];
364
365 core_trace_data[row_start..row_start + SYS_TRACE_WIDTH].copy_from_slice(system_row);
366
367 let stack_start = row_start + STACK_TRACE_OFFSET;
368 core_trace_data[stack_start..stack_start + STACK_TRACE_WIDTH].copy_from_slice(stack_row);
369 }
370}
371
372fn push_halt_opcode_row(
377 core_trace_data: &mut Vec<Felt>,
378 num_rows_before: usize,
379 last_system_state: &[Felt; SYS_TRACE_WIDTH],
380 last_stack_state: &[Felt; STACK_TRACE_WIDTH],
381) {
382 let w = CORE_TRACE_WIDTH;
383 let mut row = [ZERO; CORE_TRACE_WIDTH];
384
385 row[..SYS_TRACE_WIDTH].copy_from_slice(last_system_state);
388
389 row[STACK_TRACE_OFFSET..STACK_TRACE_OFFSET + STACK_TRACE_WIDTH]
392 .copy_from_slice(last_stack_state);
393
394 let halt_opcode = opcodes::HALT;
396 for bit_idx in 0..NUM_OP_BITS {
397 row[DECODER_TRACE_OFFSET + OP_BITS_OFFSET + bit_idx] =
398 Felt::from_u8((halt_opcode >> bit_idx) & 1);
399 }
400
401 if num_rows_before > 0 {
405 let last_row_start = (num_rows_before - 1) * w;
406 for hasher_col_idx in 0..4 {
408 let col_idx = DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + hasher_col_idx;
409 row[col_idx] = core_trace_data[last_row_start + col_idx];
410 }
411 }
412
413 row[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET + 1] = ONE;
417
418 core_trace_data.extend_from_slice(&row);
419}
420
421fn initialize_range_checker(
427 range_checker_replay: RangeCheckerReplay,
428 chiplets: &Chiplets,
429) -> RangeChecker {
430 let mut range_checker = RangeChecker::new();
431
432 for (clk, values) in range_checker_replay.into_iter() {
434 range_checker.add_range_checks(clk, &values);
435 }
436
437 chiplets.append_range_checks(&mut range_checker);
439
440 range_checker
441}
442
/// Builds and populates the [`Chiplets`] module by replaying all chiplet-related events
/// recorded during execution.
///
/// Replays, in order: hasher operations, bitwise operations, memory accesses (four
/// streams merged by clock cycle), ACE circuit evaluations, and kernel-ROM accesses.
/// After every replayed event the accumulated chiplet trace length is checked against
/// `max_trace_len`.
///
/// # Errors
/// - [`ExecutionError::TraceLenExceeded`] if the chiplet trace grows past `max_trace_len`.
/// - [`ExecutionError::Internal`] if the hasher replay references an invalid or
///   unexpected MAST node.
/// - Memory, bitwise, or kernel-ROM errors surfaced by the respective chiplets.
fn initialize_chiplets(
    kernel: Kernel,
    core_trace_contexts: &[CoreTraceFragmentContext],
    memory_writes: MemoryWritesReplay,
    bitwise: BitwiseReplay,
    kernel_replay: KernelReplay,
    hasher_for_chiplet: HasherRequestReplay,
    ace_replay: AceReplay,
    max_trace_len: usize,
) -> Result<Chiplets, ExecutionError> {
    // Fail fast as soon as the accumulated chiplet trace exceeds the allowed length.
    let check_chiplets_trace_len = |chiplets: &Chiplets| -> Result<(), ExecutionError> {
        if chiplets.trace_len() > max_trace_len {
            return Err(ExecutionError::TraceLenExceeded(max_trace_len));
        }
        Ok(())
    };

    let mut chiplets = Chiplets::new(kernel);

    // Replay hasher requests. Return values are discarded (`let _`) — only the rows the
    // operations add to the hasher trace matter here.
    for hasher_op in hasher_for_chiplet.into_iter() {
        match hasher_op {
            HasherOp::Permute(input_state) => {
                let _ = chiplets.hasher.permute(input_state);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::HashControlBlock((h1, h2, domain, expected_hash)) => {
                let _ = chiplets.hasher.hash_control_block(h1, h2, domain, expected_hash);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::HashBasicBlock((forest, node_id, expected_hash)) => {
                // The replay stores only the forest/node id; resolve the node and make
                // sure it really is a basic block before hashing its op batches.
                let node = forest
                    .get_node_by_id(node_id)
                    .ok_or(ExecutionError::Internal("invalid node ID in hasher replay"))?;
                let MastNode::Block(basic_block_node) = node else {
                    return Err(ExecutionError::Internal(
                        "expected basic block node in hasher replay",
                    ));
                };
                let op_batches = basic_block_node.op_batches();
                let _ = chiplets.hasher.hash_basic_block(op_batches, expected_hash);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::BuildMerkleRoot((value, path, index)) => {
                let _ = chiplets.hasher.build_merkle_root(value, &path, index);
                check_chiplets_trace_len(&chiplets)?;
            },
            HasherOp::UpdateMerkleRoot((old_value, new_value, path, index)) => {
                chiplets.hasher.update_merkle_root(old_value, new_value, &path, index);
                check_chiplets_trace_len(&chiplets)?;
            },
        }
    }

    // Replay bitwise (u32 AND/XOR) operations into the bitwise chiplet.
    for (bitwise_op, a, b) in bitwise {
        match bitwise_op {
            BitwiseOp::U32And => {
                chiplets.bitwise.u32and(a, b).map_exec_err_no_ctx()?;
                check_chiplets_trace_len(&chiplets)?;
            },
            BitwiseOp::U32Xor => {
                chiplets.bitwise.u32xor(a, b).map_exec_err_no_ctx()?;
                check_chiplets_trace_len(&chiplets)?;
            },
        }
    }

    // Replay memory accesses. Writes are recorded globally in `memory_writes`, while
    // reads are recorded per fragment; all four streams are merged by clock cycle so the
    // memory chiplet observes accesses in execution order.
    {
        let elements_written: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(memory_writes.iter_elements_written().map(|(element, addr, ctx, clk)| {
                MemoryAccess::WriteElement(*addr, *element, *ctx, *clk)
            }));
        let words_written: Box<dyn Iterator<Item = MemoryAccess>> = Box::new(
            memory_writes
                .iter_words_written()
                .map(|(word, addr, ctx, clk)| MemoryAccess::WriteWord(*addr, *word, *ctx, *clk)),
        );
        let elements_read: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
                ctx.replay
                    .memory_reads
                    .iter_read_elements()
                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadElement(addr, ctx, clk))
            }));
        let words_read: Box<dyn Iterator<Item = MemoryAccess>> =
            Box::new(core_trace_contexts.iter().flat_map(|ctx| {
                ctx.replay
                    .memory_reads
                    .iter_read_words()
                    .map(|(_, addr, ctx, clk)| MemoryAccess::ReadWord(addr, ctx, clk))
            }));

        // NOTE(review): `kmerge_by` with a strict `<` resolves equal-clk ties by source
        // stream order (writes before reads here) — confirm this matches the intended
        // same-cycle access ordering.
        [elements_written, words_written, elements_read, words_read]
            .into_iter()
            .kmerge_by(|a, b| a.clk() < b.clk())
            .try_for_each(|mem_access| {
                match mem_access {
                    MemoryAccess::ReadElement(addr, ctx, clk) => chiplets
                        .memory
                        .read(ctx, addr, clk)
                        .map(|_| ())
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                    MemoryAccess::WriteElement(addr, element, ctx, clk) => chiplets
                        .memory
                        .write(ctx, addr, clk, element)
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                    MemoryAccess::ReadWord(addr, ctx, clk) => chiplets
                        .memory
                        .read_word(ctx, addr, clk)
                        .map(|_| ())
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                    MemoryAccess::WriteWord(addr, word, ctx, clk) => chiplets
                        .memory
                        .write_word(ctx, addr, clk, word)
                        .map_err(ExecutionError::MemoryErrorNoCtx)?,
                }
                check_chiplets_trace_len(&chiplets)
            })?;

        // Local helper type that unifies the four access streams so they can be merged
        // into a single clk-ordered iterator above.
        enum MemoryAccess {
            ReadElement(Felt, ContextId, RowIndex),
            WriteElement(Felt, Felt, ContextId, RowIndex),
            ReadWord(Felt, ContextId, RowIndex),
            WriteWord(Felt, Word, ContextId, RowIndex),
        }

        impl MemoryAccess {
            // Clock cycle at which the access occurred; the merge key.
            fn clk(&self) -> RowIndex {
                match self {
                    MemoryAccess::ReadElement(_, _, clk) => *clk,
                    MemoryAccess::WriteElement(_, _, _, clk) => *clk,
                    MemoryAccess::ReadWord(_, _, clk) => *clk,
                    MemoryAccess::WriteWord(_, _, _, clk) => *clk,
                }
            }
        }
    }

    // Replay ACE circuit evaluations into the ACE chiplet.
    for (clk, circuit_eval) in ace_replay.into_iter() {
        chiplets.ace.add_circuit_evaluation(clk, circuit_eval);
        check_chiplets_trace_len(&chiplets)?;
    }

    // Replay kernel procedure accesses into the kernel ROM chiplet.
    for proc_hash in kernel_replay.into_iter() {
        chiplets.kernel_rom.access_proc(proc_hash).map_exec_err_no_ctx()?;
        check_chiplets_trace_len(&chiplets)?;
    }

    Ok(chiplets)
}
603
604fn pad_core_row_major(core_trace_data: &mut Vec<Felt>, main_trace_len: usize) {
606 let w = CORE_TRACE_WIDTH;
607 let total_program_rows = core_trace_data.len() / w;
608 assert!(total_program_rows <= main_trace_len);
609 assert!(total_program_rows > 0);
610
611 let num_padding_rows = main_trace_len - total_program_rows;
612 if num_padding_rows == 0 {
613 return;
614 }
615 let last_row_start = (total_program_rows - 1) * w;
616
617 let mut template = [ZERO; CORE_TRACE_WIDTH];
621 let halt_opcode = opcodes::HALT;
623 for i in 0..NUM_OP_BITS {
624 let bit_value = Felt::from_u8((halt_opcode >> i) & 1);
625 template[DECODER_TRACE_OFFSET + OP_BITS_OFFSET + i] = bit_value;
626 }
627 for i in 0..NUM_HASHER_COLUMNS {
631 let col_idx = DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + i;
632 template[col_idx] = if i < 4 {
633 core_trace_data[last_row_start + col_idx]
637 } else {
638 ZERO
639 };
640 }
641
642 template[DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET + 1] = ONE;
646
647 for i in 0..STACK_TRACE_WIDTH {
652 let col_idx = STACK_TRACE_OFFSET + i;
653 template[col_idx] = core_trace_data[last_row_start + col_idx];
656 }
657
658 core_trace_data.reserve(num_padding_rows * w);
664 for idx in 0..num_padding_rows {
665 template[CLK_COL_IDX] = Felt::from_u32((total_program_rows + idx) as u32);
666 core_trace_data.extend_from_slice(&template);
667 }
668}
669
670fn split_trace_fragment_context<'a>(
673 fragment_context: CoreTraceFragmentContext,
674 writer: RowMajorTraceWriter<'a, Felt>,
675 fragment_size: usize,
676) -> (
677 ReplayProcessor,
678 CoreTraceGenerationTracer<'a>,
679 ContinuationStack,
680 Arc<MastForest>,
681) {
682 let CoreTraceFragmentContext {
683 state: CoreTraceState { system, decoder, stack },
684 replay:
685 ExecutionReplay {
686 block_stack: block_stack_replay,
687 execution_context: execution_context_replay,
688 stack_overflow: stack_overflow_replay,
689 memory_reads: memory_reads_replay,
690 advice: advice_replay,
691 hasher: hasher_response_replay,
692 block_address: block_address_replay,
693 mast_forest_resolution: mast_forest_resolution_replay,
694 },
695 continuation,
696 initial_mast_forest,
697 } = fragment_context;
698
699 let processor = ReplayProcessor::new(
700 system,
701 stack,
702 stack_overflow_replay,
703 execution_context_replay,
704 advice_replay,
705 memory_reads_replay,
706 hasher_response_replay,
707 mast_forest_resolution_replay,
708 fragment_size.into(),
709 );
710 let tracer =
711 CoreTraceGenerationTracer::new(writer, decoder, block_address_replay, block_stack_replay);
712
713 (processor, tracer, continuation, initial_mast_forest)
714}