miden_processor/test_utils/
test_host.rs1use alloc::{
2 collections::BTreeMap,
3 string::{String, ToString},
4 sync::Arc,
5 vec::Vec,
6};
7
8use miden_core::{Felt, operations::DebugOptions};
9use miden_debug_types::{
10 DefaultSourceManager, Location, SourceFile, SourceManager, SourceManagerSync, SourceSpan,
11};
12
13use crate::{
14 BaseHost, DebugError, DebugHandler, MastForestStore, MemMastForestStore, ProcessorState,
15 SyncHost, TraceError, Word, advice::AdviceMutation, event::EventError, mast::MastForest,
16};
17
18#[derive(Debug, Clone, PartialEq, Eq)]
20pub struct ProcessorStateSnapshot {
21 clk: u32,
22 ctx: u32,
23 stack_state: Vec<Felt>,
24 stack_words: [Word; 4],
25 mem_state: Vec<(crate::MemoryAddress, Felt)>,
26}
27
28impl From<&ProcessorState<'_>> for ProcessorStateSnapshot {
29 fn from(state: &ProcessorState) -> Self {
30 ProcessorStateSnapshot {
31 clk: state.clock().into(),
32 ctx: state.ctx().into(),
33 stack_state: state.get_stack_state(),
34 stack_words: [
35 state.get_stack_word(0),
36 state.get_stack_word(4),
37 state.get_stack_word(8),
38 state.get_stack_word(12),
39 ],
40 mem_state: state.get_mem_state(state.ctx()),
41 }
42 }
43}
44
/// Collects `trace` decorator hits: how often each trace id fired and in which order.
#[derive(Debug, Clone, Default)]
pub struct TraceCollector {
    /// Hit count per trace id; ids that were never hit have no entry.
    trace_counts: BTreeMap<u32, u32>,
    /// `(trace_id, clock cycle)` for every hit, in the order the hits were recorded.
    execution_order: Vec<(u32, u64)>,
}

impl TraceCollector {
    /// Returns an empty collector (identical to [`TraceCollector::default`]).
    pub fn new() -> Self {
        Self {
            trace_counts: BTreeMap::new(),
            execution_order: Vec::new(),
        }
    }

    /// Returns how many times `trace_id` has been hit, or `0` if it never was.
    pub fn get_trace_count(&self, trace_id: u32) -> u32 {
        match self.trace_counts.get(&trace_id) {
            Some(&hits) => hits,
            None => 0,
        }
    }

    /// Returns every recorded `(trace_id, clock cycle)` pair, oldest first.
    pub fn get_execution_order(&self) -> &[(u32, u64)] {
        self.execution_order.as_slice()
    }
}
70
71impl DebugHandler for TraceCollector {
72 fn on_trace(&mut self, process: &ProcessorState, trace_id: u32) -> Result<(), TraceError> {
73 *self.trace_counts.entry(trace_id).or_insert(0) += 1;
75
76 self.execution_order.push((trace_id, process.clock().into()));
78
79 Ok(())
80 }
81}
82
83#[derive(Debug, Clone)]
86pub struct TestHost<S: SourceManager = DefaultSourceManager> {
87 trace_collector: TraceCollector,
89
90 pub event_handler: Vec<u32>,
92
93 pub debug_handler: Vec<String>,
95
96 snapshots: BTreeMap<u32, Vec<ProcessorStateSnapshot>>,
98
99 store: MemMastForestStore,
101
102 pub source_manager: Arc<S>,
104}
105
106impl TestHost {
107 pub fn new() -> Self {
109 Self {
110 trace_collector: TraceCollector::new(),
111 event_handler: Vec::new(),
112 debug_handler: Vec::new(),
113 snapshots: BTreeMap::new(),
114 store: MemMastForestStore::default(),
115 source_manager: Arc::new(DefaultSourceManager::default()),
116 }
117 }
118
119 pub fn with_kernel_forest(kernel_forest: Arc<MastForest>) -> Self {
121 let mut store = MemMastForestStore::default();
122 store.insert(kernel_forest.clone());
123 Self {
124 trace_collector: TraceCollector::new(),
125 event_handler: Vec::new(),
126 debug_handler: Vec::new(),
127 snapshots: BTreeMap::new(),
128 store,
129 source_manager: Arc::new(DefaultSourceManager::default()),
130 }
131 }
132
133 pub fn get_trace_count(&self, trace_id: u32) -> u32 {
135 self.trace_collector.get_trace_count(trace_id)
136 }
137
138 pub fn get_execution_order(&self) -> &[(u32, u64)] {
140 self.trace_collector.get_execution_order()
141 }
142
143 pub fn snapshots(&self) -> &BTreeMap<u32, Vec<ProcessorStateSnapshot>> {
145 &self.snapshots
146 }
147}
148
149impl Default for TestHost {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155impl<S> BaseHost for TestHost<S>
156where
157 S: SourceManagerSync,
158{
159 fn get_label_and_source_file(
160 &self,
161 location: &Location,
162 ) -> (SourceSpan, Option<Arc<SourceFile>>) {
163 let maybe_file = self.source_manager.get_by_uri(location.uri());
164 let span = self.source_manager.location_to_span(location.clone()).unwrap_or_default();
165 (span, maybe_file)
166 }
167
168 fn on_debug(
169 &mut self,
170 _process: &ProcessorState,
171 options: &DebugOptions,
172 ) -> Result<(), DebugError> {
173 self.debug_handler.push(options.to_string());
174 Ok(())
175 }
176
177 fn on_trace(&mut self, process: &ProcessorState, trace_id: u32) -> Result<(), TraceError> {
178 self.trace_collector.on_trace(process, trace_id)?;
180
181 let snapshot = ProcessorStateSnapshot::from(process);
183 self.snapshots.entry(trace_id).or_default().push(snapshot);
184
185 Ok(())
186 }
187}
188
189impl<S> SyncHost for TestHost<S>
190where
191 S: SourceManagerSync,
192{
193 fn get_mast_forest(&self, node_digest: &Word) -> Option<Arc<MastForest>> {
194 self.store.get(node_digest)
195 }
196
197 fn on_event(&mut self, process: &ProcessorState) -> Result<Vec<AdviceMutation>, EventError> {
198 let event_id: u32 = process.get_stack_item(0).as_canonical_u64().try_into().unwrap();
199 self.event_handler.push(event_id);
200 Ok(Vec::new())
201 }
202}