Skip to main content

miden_processor/test_utils/
test_host.rs

1use alloc::{
2    collections::BTreeMap,
3    string::{String, ToString},
4    sync::Arc,
5    vec::Vec,
6};
7
8use miden_core::{Felt, operations::DebugOptions};
9use miden_debug_types::{
10    DefaultSourceManager, Location, SourceFile, SourceManager, SourceManagerSync, SourceSpan,
11};
12
13use crate::{
14    BaseHost, DebugError, DebugHandler, MastForestStore, MemMastForestStore, ProcessorState,
15    SyncHost, TraceError, Word, advice::AdviceMutation, event::EventError, mast::MastForest,
16};
17
18/// A snapshot of the processor state for consistency checking between processors.
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub struct ProcessorStateSnapshot {
21    clk: u32,
22    ctx: u32,
23    stack_state: Vec<Felt>,
24    stack_words: [Word; 4],
25    mem_state: Vec<(crate::MemoryAddress, Felt)>,
26}
27
28impl From<&ProcessorState<'_>> for ProcessorStateSnapshot {
29    fn from(state: &ProcessorState) -> Self {
30        ProcessorStateSnapshot {
31            clk: state.clock().into(),
32            ctx: state.ctx().into(),
33            stack_state: state.get_stack_state(),
34            stack_words: [
35                state.get_stack_word(0),
36                state.get_stack_word(4),
37                state.get_stack_word(8),
38                state.get_stack_word(12),
39            ],
40            mem_state: state.get_mem_state(state.ctx()),
41        }
42    }
43}
44
/// Debug handler that tallies trace events and remembers the order they arrived in.
#[derive(Clone, Debug, Default)]
pub struct TraceCollector {
    /// Number of times each trace ID has been observed.
    trace_counts: BTreeMap<u32, u32>,
    /// Observed trace events in arrival order, as (trace ID, clock cycle) pairs.
    execution_order: Vec<(u32, u64)>,
}

impl TraceCollector {
    /// Returns a collector that has not observed any trace events yet.
    pub fn new() -> Self {
        Self {
            trace_counts: BTreeMap::new(),
            execution_order: Vec::new(),
        }
    }

    /// Returns how many times `trace_id` has been observed, or 0 if it never was.
    pub fn get_trace_count(&self, trace_id: u32) -> u32 {
        self.trace_counts.get(&trace_id).map_or(0, |count| *count)
    }

    /// Returns the observed (trace ID, clock cycle) pairs in arrival order.
    pub fn get_execution_order(&self) -> &[(u32, u64)] {
        self.execution_order.as_slice()
    }
}
70
71impl DebugHandler for TraceCollector {
72    fn on_trace(&mut self, process: &ProcessorState, trace_id: u32) -> Result<(), TraceError> {
73        // Count the trace event
74        *self.trace_counts.entry(trace_id).or_insert(0) += 1;
75
76        // Record the execution order with clock cycle
77        self.execution_order.push((trace_id, process.clock().into()));
78
79        Ok(())
80    }
81}
82
83/// A unified testing host that combines trace collection, event handling,
84/// debug handling, and process state consistency checking.
85#[derive(Debug, Clone)]
86pub struct TestHost<S: SourceManager = DefaultSourceManager> {
87    /// Trace collection functionality (counts and execution order)
88    trace_collector: TraceCollector,
89
90    /// List of event IDs that have been received
91    pub event_handler: Vec<u32>,
92
93    /// List of debug command strings that have been received
94    pub debug_handler: Vec<String>,
95
96    /// Process state snapshots for consistency checking
97    snapshots: BTreeMap<u32, Vec<ProcessorStateSnapshot>>,
98
99    /// MAST forest store for external node resolution
100    store: MemMastForestStore,
101
102    /// Source manager for debugging information
103    pub source_manager: Arc<S>,
104}
105
106impl TestHost {
107    /// Creates a new TestHost with minimal functionality for basic testing.
108    pub fn new() -> Self {
109        Self {
110            trace_collector: TraceCollector::new(),
111            event_handler: Vec::new(),
112            debug_handler: Vec::new(),
113            snapshots: BTreeMap::new(),
114            store: MemMastForestStore::default(),
115            source_manager: Arc::new(DefaultSourceManager::default()),
116        }
117    }
118
119    /// Creates a new TestHost with a kernel forest for full consistency testing.
120    pub fn with_kernel_forest(kernel_forest: Arc<MastForest>) -> Self {
121        let mut store = MemMastForestStore::default();
122        store.insert(kernel_forest.clone());
123        Self {
124            trace_collector: TraceCollector::new(),
125            event_handler: Vec::new(),
126            debug_handler: Vec::new(),
127            snapshots: BTreeMap::new(),
128            store,
129            source_manager: Arc::new(DefaultSourceManager::default()),
130        }
131    }
132
133    /// Gets the count of executions for a specific trace ID.
134    pub fn get_trace_count(&self, trace_id: u32) -> u32 {
135        self.trace_collector.get_trace_count(trace_id)
136    }
137
138    /// Gets the execution order as a reference (with clock cycles).
139    pub fn get_execution_order(&self) -> &[(u32, u64)] {
140        self.trace_collector.get_execution_order()
141    }
142
143    /// Gets mutable access to all snapshots.
144    pub fn snapshots(&self) -> &BTreeMap<u32, Vec<ProcessorStateSnapshot>> {
145        &self.snapshots
146    }
147}
148
149impl Default for TestHost {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155impl<S> BaseHost for TestHost<S>
156where
157    S: SourceManagerSync,
158{
159    fn get_label_and_source_file(
160        &self,
161        location: &Location,
162    ) -> (SourceSpan, Option<Arc<SourceFile>>) {
163        let maybe_file = self.source_manager.get_by_uri(location.uri());
164        let span = self.source_manager.location_to_span(location.clone()).unwrap_or_default();
165        (span, maybe_file)
166    }
167
168    fn on_debug(
169        &mut self,
170        _process: &ProcessorState,
171        options: &DebugOptions,
172    ) -> Result<(), DebugError> {
173        self.debug_handler.push(options.to_string());
174        Ok(())
175    }
176
177    fn on_trace(&mut self, process: &ProcessorState, trace_id: u32) -> Result<(), TraceError> {
178        // Forward to trace collector for counting and execution order tracking
179        self.trace_collector.on_trace(process, trace_id)?;
180
181        // Also collect process state snapshot for consistency checking
182        let snapshot = ProcessorStateSnapshot::from(process);
183        self.snapshots.entry(trace_id).or_default().push(snapshot);
184
185        Ok(())
186    }
187}
188
189impl<S> SyncHost for TestHost<S>
190where
191    S: SourceManagerSync,
192{
193    fn get_mast_forest(&self, node_digest: &Word) -> Option<Arc<MastForest>> {
194        self.store.get(node_digest)
195    }
196
197    fn on_event(&mut self, process: &ProcessorState) -> Result<Vec<AdviceMutation>, EventError> {
198        let event_id: u32 = process.get_stack_item(0).as_canonical_u64().try_into().unwrap();
199        self.event_handler.push(event_id);
200        Ok(Vec::new())
201    }
202}